batch search queries in the vector database

KCaverly created

Change summary

crates/search/src/project_search.rs         |  3 
crates/semantic_index/src/db.rs             | 72 +++++++++++++-------
crates/semantic_index/src/semantic_index.rs | 76 +++++++++++++++++-----
3 files changed, 106 insertions(+), 45 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -30,6 +30,7 @@ use std::{
     ops::{Not, Range},
     path::PathBuf,
     sync::Arc,
+    time::Instant,
 };
 use util::ResultExt as _;
 use workspace::{
@@ -192,6 +193,7 @@ impl ProjectSearch {
         exclude_files: Vec<GlobMatcher>,
         cx: &mut ModelContext<Self>,
     ) {
+        let t0 = Instant::now();
         let search = SemanticIndex::global(cx).map(|index| {
             index.update(cx, |semantic_index, cx| {
                 semantic_index.search_project(
@@ -208,6 +210,7 @@ impl ProjectSearch {
         self.match_ranges.clear();
         self.pending_search = Some(cx.spawn(|this, mut cx| async move {
             let results = search?.await.log_err()?;
+            log::trace!("semantic search elapsed: {:?}", t0.elapsed().as_millis());
 
             let (_task, mut match_ranges) = this.update(&mut cx, |this, cx| {
                 this.excerpts.update(cx, |excerpts, cx| {

crates/semantic_index/src/db.rs 🔗

@@ -267,41 +267,56 @@ impl VectorDatabase {
 
     pub fn top_k_search(
         &self,
-        worktree_ids: &[i64],
         query_embedding: &Vec<f32>,
         limit: usize,
-        include_globs: Vec<GlobMatcher>,
-        exclude_globs: Vec<GlobMatcher>,
-    ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
+        file_ids: &[i64],
+    ) -> Result<Vec<(i64, f32)>> {
         let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-        self.for_each_document(
-            &worktree_ids,
-            include_globs,
-            exclude_globs,
-            |id, embedding| {
-                let similarity = dot(&embedding, &query_embedding);
-                let ix = match results.binary_search_by(|(_, s)| {
-                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                }) {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            },
-        )?;
+        self.for_each_document(file_ids, |id, embedding| {
+            let similarity = dot(&embedding, &query_embedding);
+            let ix = match results
+                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+            {
+                Ok(ix) => ix,
+                Err(ix) => ix,
+            };
+            results.insert(ix, (id, similarity));
+            results.truncate(limit);
+        })?;
 
-        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-        self.get_documents_by_ids(&ids)
+        Ok(results)
     }
 
-    fn for_each_document(
+    // pub fn top_k_search(
+    //     &self,
+    //     worktree_ids: &[i64],
+    //     query_embedding: &Vec<f32>,
+    //     limit: usize,
+    //     file_ids: Vec<i64>,
+    // ) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
+    //     let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+    //     self.for_each_document(&worktree_ids, file_ids, |id, embedding| {
+    //         let similarity = dot(&embedding, &query_embedding);
+    //         let ix = match results
+    //             .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+    //         {
+    //             Ok(ix) => ix,
+    //             Err(ix) => ix,
+    //         };
+    //         results.insert(ix, (id, similarity));
+    //         results.truncate(limit);
+    //     })?;
+
+    //     let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+    //     self.get_documents_by_ids(&ids)
+    // }
+
+    pub fn retrieve_included_file_ids(
         &self,
         worktree_ids: &[i64],
         include_globs: Vec<GlobMatcher>,
         exclude_globs: Vec<GlobMatcher>,
-        mut f: impl FnMut(i64, Vec<f32>),
-    ) -> Result<()> {
+    ) -> Result<Vec<i64>> {
         let mut file_query = self.db.prepare(
             "
             SELECT
@@ -315,6 +330,7 @@ impl VectorDatabase {
 
         let mut file_ids = Vec::<i64>::new();
         let mut rows = file_query.query([ids_to_sql(worktree_ids)])?;
+
         while let Some(row) = rows.next()? {
             let file_id = row.get(0)?;
             let relative_path = row.get_ref(1)?.as_str()?;
@@ -330,6 +346,10 @@ impl VectorDatabase {
             }
         }
 
+        Ok(file_ids)
+    }
+
+    fn for_each_document(&self, file_ids: &[i64], mut f: impl FnMut(i64, Vec<f32>)) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
             SELECT
@@ -350,7 +370,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, Range<usize>)>> {
         let mut statement = self.db.prepare(
             "
                 SELECT

crates/semantic_index/src/semantic_index.rs 🔗

@@ -20,6 +20,7 @@ use postage::watch;
 use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{
+    cmp::Ordering,
     collections::HashMap,
     mem,
     ops::Range,
@@ -704,27 +705,64 @@ impl SemanticIndex {
         let database_url = self.database_url.clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
-            let documents = cx
-                .background()
-                .spawn(async move {
-                    let database = VectorDatabase::new(fs, database_url).await?;
+            let database = VectorDatabase::new(fs.clone(), database_url.clone()).await?;
 
-                    let phrase_embedding = embedding_provider
-                        .embed_batch(vec![&phrase])
-                        .await?
-                        .into_iter()
-                        .next()
-                        .unwrap();
+            let phrase_embedding = embedding_provider
+                .embed_batch(vec![&phrase])
+                .await?
+                .into_iter()
+                .next()
+                .unwrap();
 
-                    database.top_k_search(
-                        &worktree_db_ids,
-                        &phrase_embedding,
-                        limit,
-                        include_globs,
-                        exclude_globs,
-                    )
-                })
-                .await?;
+            let file_ids = database.retrieve_included_file_ids(
+                &worktree_db_ids,
+                include_globs,
+                exclude_globs,
+            )?;
+
+            let batch_n = cx.background().num_cpus();
+            let batch_size = file_ids.clone().len() / batch_n;
+
+            let mut result_tasks = Vec::new();
+            for batch in file_ids.chunks(batch_size) {
+                let batch = batch.into_iter().map(|v| *v).collect::<Vec<i64>>();
+                let limit = limit.clone();
+                let fs = fs.clone();
+                let database_url = database_url.clone();
+                let phrase_embedding = phrase_embedding.clone();
+                let task = cx.background().spawn(async move {
+                    let database = VectorDatabase::new(fs, database_url).await.log_err();
+                    if database.is_none() {
+                        return Err(anyhow!("failed to acquire database connection"));
+                    } else {
+                        database
+                            .unwrap()
+                            .top_k_search(&phrase_embedding, limit, batch.as_slice())
+                    }
+                });
+                result_tasks.push(task);
+            }
+
+            let batch_results = futures::future::join_all(result_tasks).await;
+
+            let mut results = Vec::new();
+            for batch_result in batch_results {
+                if batch_result.is_ok() {
+                    for (id, similarity) in batch_result.unwrap() {
+                        let ix = match results.binary_search_by(|(_, s)| {
+                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
+                        }) {
+                            Ok(ix) => ix,
+                            Err(ix) => ix,
+                        };
+                        results.insert(ix, (id, similarity));
+                        results.truncate(limit);
+                    }
+                }
+            }
+
+            let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
+            let documents = database.get_documents_by_ids(ids.as_slice())?;
 
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();