implement new search strategy (#3029)

Kyle Caverly created

Augment current search strategy in semantic search, reducing search
times by ~60%

Release Notes:

- Implemented minimum batch sizes for concurrent database reads.
- Batch embedding matrix multiplication.
- Calculate matmul with ndarray

Change summary

Cargo.lock                                  |  25 ++++
crates/semantic_index/Cargo.toml            |   1 
crates/semantic_index/src/db.rs             | 130 +++++++++++++++-------
crates/semantic_index/src/semantic_index.rs |  17 ++
script/evaluate_semantic_index              |   2 
5 files changed, 128 insertions(+), 47 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4580,6 +4580,19 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex 0.4.4",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "ndk"
 version = "0.7.0"
@@ -4706,7 +4719,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
 dependencies = [
  "num-bigint 0.2.6",
- "num-complex",
+ "num-complex 0.2.4",
  "num-integer",
  "num-iter",
  "num-rational 0.2.4",
@@ -4762,6 +4775,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-complex"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-derive"
 version = "0.3.3"
@@ -6751,6 +6773,7 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "ndarray",
  "node_runtime",
  "ordered-float",
  "parking_lot 0.11.2",

crates/semantic_index/Cargo.toml 🔗

@@ -39,6 +39,7 @@ rand.workspace = true
 schemars.workspace = true
 globset.workspace = true
 sha1 = "0.10.5"
+ndarray = { version = "0.15.0" }
 
 [dev-dependencies]
 collections = { path = "../collections", features = ["test-support"] }

crates/semantic_index/src/db.rs 🔗

@@ -7,13 +7,13 @@ use anyhow::{anyhow, Context, Result};
 use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
+use ndarray::{Array1, Array2};
 use ordered_float::OrderedFloat;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
 use rusqlite::params;
 use rusqlite::types::Value;
 use std::{
-    cmp::Reverse,
     future::Future,
     ops::Range,
     path::{Path, PathBuf},
@@ -23,6 +23,13 @@ use std::{
 };
 use util::TryFutureExt;
 
+pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
+    let mut indices = (0..data.len()).collect::<Vec<_>>();
+    indices.sort_by_key(|&i| &data[i]);
+    indices.reverse();
+    indices
+}
+
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
@@ -409,23 +416,91 @@ impl VectorDatabase {
         limit: usize,
         file_ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
-        let query_embedding = query_embedding.clone();
         let file_ids = file_ids.to_vec();
+        let query = query_embedding.clone().0;
+        let query = Array1::from_vec(query);
         self.transact(move |db| {
-            let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
-            Self::for_each_span(db, &file_ids, |id, embedding| {
-                let similarity = embedding.similarity(&query_embedding);
-                let ix = match results
-                    .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
-                {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            })?;
+            let mut query_statement = db.prepare(
+                "
+                    SELECT
+                        id, embedding
+                    FROM
+                        spans
+                    WHERE
+                        file_id IN rarray(?)
+                    ",
+            )?;
+
+            let deserialized_rows = query_statement
+                .query_map(params![ids_to_sql(&file_ids)], |row| {
+                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
+                })?
+                .filter_map(|row| row.ok())
+                .collect::<Vec<(usize, Embedding)>>();
+
+            if deserialized_rows.len() == 0 {
+                return Ok(Vec::new());
+            }
+
+            // Get Length of Embeddings Returned
+            let embedding_len = deserialized_rows[0].1 .0.len();
+
+            let batch_n = 1000;
+            let mut batches = Vec::new();
+            let mut batch_ids = Vec::new();
+            let mut batch_embeddings: Vec<f32> = Vec::new();
+            deserialized_rows.iter().for_each(|(id, embedding)| {
+                batch_ids.push(id);
+                batch_embeddings.extend(&embedding.0);
+
+                if batch_ids.len() == batch_n {
+                    let embeddings = std::mem::take(&mut batch_embeddings);
+                    let ids = std::mem::take(&mut batch_ids);
+                    let array =
+                        Array2::from_shape_vec((ids.len(), embedding_len.clone()), embeddings);
+                    match array {
+                        Ok(array) => {
+                            batches.push((ids, array));
+                        }
+                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                    }
+                }
+            });
 
-            anyhow::Ok(results)
+            if batch_ids.len() > 0 {
+                let array = Array2::from_shape_vec(
+                    (batch_ids.len(), embedding_len),
+                    batch_embeddings.clone(),
+                );
+                match array {
+                    Ok(array) => {
+                        batches.push((batch_ids.clone(), array));
+                    }
+                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                }
+            }
+
+            let mut ids: Vec<usize> = Vec::new();
+            let mut results = Vec::new();
+            for (batch_ids, array) in batches {
+                let scores = array
+                    .dot(&query.t())
+                    .to_vec()
+                    .iter()
+                    .map(|score| OrderedFloat(*score))
+                    .collect::<Vec<OrderedFloat<f32>>>();
+                results.extend(scores);
+                ids.extend(batch_ids);
+            }
+
+            let sorted_idx = argsort(&results);
+            let mut sorted_results = Vec::new();
+            let last_idx = limit.min(sorted_idx.len());
+            for idx in &sorted_idx[0..last_idx] {
+                sorted_results.push((ids[*idx] as i64, results[*idx]))
+            }
+
+            Ok(sorted_results)
         })
     }
 
@@ -468,31 +543,6 @@ impl VectorDatabase {
         })
     }
 
-    fn for_each_span(
-        db: &rusqlite::Connection,
-        file_ids: &[i64],
-        mut f: impl FnMut(i64, Embedding),
-    ) -> Result<()> {
-        let mut query_statement = db.prepare(
-            "
-            SELECT
-                id, embedding
-            FROM
-                spans
-            WHERE
-                file_id IN rarray(?)
-            ",
-        )?;
-
-        query_statement
-            .query_map(params![ids_to_sql(&file_ids)], |row| {
-                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
-            })?
-            .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding));
-        Ok(())
-    }
-
     pub fn spans_for_ids(
         &self,
         ids: &[i64],

crates/semantic_index/src/semantic_index.rs 🔗

@@ -705,11 +705,13 @@ impl SemanticIndex {
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
+            let t0 = Instant::now();
             let query = embedding_provider
                 .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
+            log::trace!("Embedding Search Query: {:?}ms", t0.elapsed().as_millis());
 
             let search_start = Instant::now();
             let modified_buffer_results = this.update(&mut cx, |this, cx| {
@@ -787,10 +789,15 @@ impl SemanticIndex {
 
             let batch_n = cx.background().num_cpus();
             let ids_len = file_ids.clone().len();
-            let batch_size = if ids_len <= batch_n {
-                ids_len
-            } else {
-                ids_len / batch_n
+            let minimum_batch_size = 50;
+
+            let batch_size = {
+                let size = ids_len / batch_n;
+                if size < minimum_batch_size {
+                    minimum_batch_size
+                } else {
+                    size
+                }
             };
 
             let mut batch_results = Vec::new();
@@ -822,6 +829,7 @@ impl SemanticIndex {
                             Ok(ix) => ix,
                             Err(ix) => ix,
                         };
+
                         results.insert(ix, (id, similarity));
                         results.truncate(limit);
                     }
@@ -856,7 +864,6 @@ impl SemanticIndex {
             })?;
 
             let buffers = futures::future::join_all(tasks).await;
-
             Ok(buffers
                 .into_iter()
                 .zip(ranges)

script/evaluate_semantic_index 🔗

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-cargo run -p semantic_index --example eval
+RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release