implement new search strategy

KCaverly created this pull request

Change summary

Cargo.lock                                  | 121 ++++++++++++++++++++++
crates/semantic_index/Cargo.toml            |   2 
crates/semantic_index/src/db.rs             | 123 +++++++++++++++-------
crates/semantic_index/src/semantic_index.rs |  29 ++++-
script/evaluate_semantic_index              |   2 
5 files changed, 227 insertions(+), 50 deletions(-)

Detailed changes

Cargo.lock

@@ -570,7 +570,7 @@ dependencies = [
  "libc",
  "pin-project",
  "redox_syscall 0.2.16",
- "xattr",
+ "xattr 0.2.3",
 ]
 
 [[package]]
@@ -903,6 +903,15 @@ dependencies = [
  "wyz",
 ]
 
+[[package]]
+name = "blas-src"
+version = "0.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bb48fbaa7a0cb9d6d96c46bac6cedb16f13a10aebcef1c4e73515aaae8c9909d"
+dependencies = [
+ "openblas-src",
+]
+
 [[package]]
 name = "block"
 version = "0.1.6"
@@ -1181,6 +1190,15 @@ version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a2698f953def977c68f935bb0dfa959375ad4638570e969e2f1e9f433cbf1af6"
 
+[[package]]
+name = "cblas-sys"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "cc"
 version = "1.0.83"
@@ -4580,6 +4598,21 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "cblas-sys",
+ "libc",
+ "matrixmultiply",
+ "num-complex 0.4.4",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "ndk"
 version = "0.7.0"
@@ -4706,7 +4739,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b8536030f9fea7127f841b45bb6243b27255787fb4eb83958aa1ef9d2fdc0c36"
 dependencies = [
  "num-bigint 0.2.6",
- "num-complex",
+ "num-complex 0.2.4",
  "num-integer",
  "num-iter",
  "num-rational 0.2.4",
@@ -4762,6 +4795,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-complex"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-derive"
 version = "0.3.3"
@@ -4948,6 +4990,32 @@ version = "0.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
 
+[[package]]
+name = "openblas-build"
+version = "0.10.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f"
+dependencies = [
+ "anyhow",
+ "flate2",
+ "native-tls",
+ "tar",
+ "thiserror",
+ "ureq",
+ "walkdir",
+]
+
+[[package]]
+name = "openblas-src"
+version = "0.10.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501"
+dependencies = [
+ "dirs 3.0.2",
+ "openblas-build",
+ "vcpkg",
+]
+
 [[package]]
 name = "openssl"
 version = "0.10.57"
@@ -6422,6 +6490,18 @@ dependencies = [
  "webpki 0.22.1",
 ]
 
+[[package]]
+name = "rustls-native-certs"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00"
+dependencies = [
+ "openssl-probe",
+ "rustls-pemfile",
+ "schannel",
+ "security-framework",
+]
+
 [[package]]
 name = "rustls-pemfile"
 version = "1.0.3"
@@ -6740,6 +6820,7 @@ dependencies = [
  "ai",
  "anyhow",
  "async-trait",
+ "blas-src",
  "client",
  "collections",
  "ctor",
@@ -6751,6 +6832,7 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "ndarray",
  "node_runtime",
  "ordered-float",
  "parking_lot 0.11.2",
@@ -7629,6 +7711,17 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
 
+[[package]]
+name = "tar"
+version = "0.4.40"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb"
+dependencies = [
+ "filetime",
+ "libc",
+ "xattr 1.0.1",
+]
+
 [[package]]
 name = "target-lexicon"
 version = "0.12.11"
@@ -8703,6 +8796,21 @@ version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
 
+[[package]]
+name = "ureq"
+version = "2.7.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
+dependencies = [
+ "base64 0.21.4",
+ "flate2",
+ "log",
+ "native-tls",
+ "once_cell",
+ "rustls-native-certs",
+ "url",
+]
+
 [[package]]
 name = "url"
 version = "2.4.1"
@@ -9754,6 +9862,15 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "xattr"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f4686009f71ff3e5c4dbcf1a282d0a44db3f021ba69350cd42086b3e5f1c6985"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "xmlparser"
 version = "0.13.5"

crates/semantic_index/Cargo.toml

@@ -39,6 +39,8 @@ rand.workspace = true
 schemars.workspace = true
 globset.workspace = true
 sha1 = "0.10.5"
+ndarray = { version = "0.15.0", features = ["blas"] }
+blas-src = { version = "0.8", features = ["openblas"] }
 
 [dev-dependencies]
 collections = { path = "../collections", features = ["test-support"] }

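The two new dependencies route ndarray's dense products through OpenBLAS: the "blas" feature makes ndarray delegate matrix multiplication to a BLAS backend, and blas-src with the "openblas" feature selects and links that backend (which is why db.rs below gains an extern crate blas_src;). A minimal sketch of the scoring path this enables, with illustrative sizes rather than the 1536-dimensional embeddings the index actually stores:

extern crate blas_src; // link the OpenBLAS backend selected in Cargo.toml

use ndarray::{Array1, Array2};

fn main() {
    // Four stored embeddings of dimension 3 (illustrative; the index uses 1536-d vectors).
    let embeddings = Array2::from_shape_vec(
        (4, 3),
        vec![
            1.0_f32, 0.0, 0.0,
            0.0, 1.0, 0.0,
            0.5, 0.5, 0.0,
            0.0, 0.0, 1.0,
        ],
    )
    .unwrap();
    let query = Array1::from_vec(vec![1.0_f32, 0.0, 0.0]);

    // One BLAS-backed matrix-vector product scores every row against the query.
    let scores = embeddings.dot(&query);
    println!("{scores}");
}
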
crates/semantic_index/src/db.rs

@@ -1,3 +1,5 @@
+extern crate blas_src;
+
 use crate::{
     parsing::{Span, SpanDigest},
     SEMANTIC_INDEX_VERSION,
@@ -7,6 +9,7 @@ use anyhow::{anyhow, Context, Result};
 use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
+use ndarray::{Array1, Array2};
 use ordered_float::OrderedFloat;
 use project::{search::PathMatcher, Fs};
 use rpc::proto::Timestamp;
@@ -19,10 +22,16 @@ use std::{
     path::{Path, PathBuf},
     rc::Rc,
     sync::Arc,
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use util::TryFutureExt;
 
+pub fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
+    let mut indices = (0..data.len()).collect::<Vec<_>>();
+    indices.sort_by_key(|&i| &data[i]);
+    indices
+}
+
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
@@ -409,23 +418,82 @@ impl VectorDatabase {
         limit: usize,
         file_ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, OrderedFloat<f32>)>>> {
-        let query_embedding = query_embedding.clone();
         let file_ids = file_ids.to_vec();
+        let query = query_embedding.clone().0;
+        let query = Array1::from_vec(query);
         self.transact(move |db| {
-            let mut results = Vec::<(i64, OrderedFloat<f32>)>::with_capacity(limit + 1);
-            Self::for_each_span(db, &file_ids, |id, embedding| {
-                let similarity = embedding.similarity(&query_embedding);
-                let ix = match results
-                    .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
-                {
-                    Ok(ix) => ix,
-                    Err(ix) => ix,
-                };
-                results.insert(ix, (id, similarity));
-                results.truncate(limit);
-            })?;
+            let mut query_statement = db.prepare(
+                "
+                    SELECT
+                        id, embedding
+                    FROM
+                        spans
+                    WHERE
+                        file_id IN rarray(?)
+                    ",
+            )?;
+
+            let deserialized_rows = query_statement
+                .query_map(params![ids_to_sql(&file_ids)], |row| {
+                    Ok((row.get::<_, usize>(0)?, row.get::<_, Embedding>(1)?))
+                })?
+                .filter_map(|row| row.ok())
+                .collect::<Vec<(usize, Embedding)>>();
+
+            let batch_n = 250;
+            let mut batches = Vec::new();
+            let mut batch_ids = Vec::new();
+            let mut batch_embeddings: Vec<f32> = Vec::new();
+            deserialized_rows.iter().for_each(|(id, embedding)| {
+                batch_ids.push(id);
+                batch_embeddings.extend(&embedding.0);
+                if batch_ids.len() == batch_n {
+                    let array =
+                        Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
+                    match array {
+                        Ok(array) => {
+                            batches.push((batch_ids.clone(), array));
+                        }
+                        Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                    }
+
+                    batch_ids = Vec::new();
+                    batch_embeddings = Vec::new();
+                }
+            });
+
+            if batch_ids.len() > 0 {
+                let array =
+                    Array2::from_shape_vec((batch_ids.len(), 1536), batch_embeddings.clone());
+                match array {
+                    Ok(array) => {
+                        batches.push((batch_ids.clone(), array));
+                    }
+                    Err(err) => log::error!("Failed to deserialize to ndarray: {:?}", err),
+                }
+            }
+
+            let mut ids: Vec<usize> = Vec::new();
+            let mut results = Vec::new();
+            for (batch_ids, array) in batches {
+                let scores = array
+                    .dot(&query.t())
+                    .to_vec()
+                    .iter()
+                    .map(|score| OrderedFloat(*score))
+                    .collect::<Vec<OrderedFloat<f32>>>();
+                results.extend(scores);
+                ids.extend(batch_ids);
+            }
 
-            anyhow::Ok(results)
+            let sorted_idx = argsort(&results);
+            let mut sorted_results = Vec::new();
+            let last_idx = limit.min(sorted_idx.len());
+            for idx in &sorted_idx[0..last_idx] {
+                sorted_results.push((ids[*idx] as i64, results[*idx]))
+            }
+
+            Ok(sorted_results)
         })
     }
 
@@ -468,31 +536,6 @@ impl VectorDatabase {
         })
     }
 
-    fn for_each_span(
-        db: &rusqlite::Connection,
-        file_ids: &[i64],
-        mut f: impl FnMut(i64, Embedding),
-    ) -> Result<()> {
-        let mut query_statement = db.prepare(
-            "
-            SELECT
-                id, embedding
-            FROM
-                spans
-            WHERE
-                file_id IN rarray(?)
-            ",
-        )?;
-
-        query_statement
-            .query_map(params![ids_to_sql(&file_ids)], |row| {
-                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
-            })?
-            .filter_map(|row| row.ok())
-            .for_each(|(id, embedding)| f(id, embedding));
-        Ok(())
-    }
-
     pub fn spans_for_ids(
         &self,
         ids: &[i64],

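The db.rs change replaces the per-row insertion sort with a bulk pass: all candidate spans for the given file ids are read in one query, packed into 250-row Array2 batches, scored with a single matrix-vector dot product per batch, and the concatenated scores are argsorted to select limit results. A sketch of that final selection step in isolation; top_k is a hypothetical helper (not part of the diff) and assumes a higher dot product means a closer match:

use ordered_float::OrderedFloat;

// Indices that would sort `data` ascending, mirroring the argsort helper added in db.rs.
fn argsort<T: Ord>(data: &[T]) -> Vec<usize> {
    let mut indices = (0..data.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| &data[i]);
    indices
}

// Hypothetical helper: keep the `limit` best-scoring ids once every batch has been scored.
fn top_k(ids: &[i64], scores: &[OrderedFloat<f32>], limit: usize) -> Vec<(i64, OrderedFloat<f32>)> {
    let sorted_idx = argsort(scores);
    sorted_idx
        .iter()
        .rev() // walk from the highest score down
        .take(limit)
        .map(|&i| (ids[i], scores[i]))
        .collect()
}
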
crates/semantic_index/src/semantic_index.rs

@@ -705,11 +705,13 @@ impl SemanticIndex {
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
+            let t0 = Instant::now();
             let query = embedding_provider
                 .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
+            log::trace!("Embedding Search Query: {:?}", t0.elapsed().as_millis());
 
             let search_start = Instant::now();
             let modified_buffer_results = this.update(&mut cx, |this, cx| {
@@ -787,10 +789,15 @@ impl SemanticIndex {
 
             let batch_n = cx.background().num_cpus();
             let ids_len = file_ids.clone().len();
-            let batch_size = if ids_len <= batch_n {
-                ids_len
-            } else {
-                ids_len / batch_n
+            let minimum_batch_size = 50;
+
+            let batch_size = {
+                let size = ids_len / batch_n;
+                if size < minimum_batch_size {
+                    minimum_batch_size
+                } else {
+                    size
+                }
             };
 
             let mut batch_results = Vec::new();
@@ -813,17 +820,26 @@ impl SemanticIndex {
             let batch_results = futures::future::join_all(batch_results).await;
 
             let mut results = Vec::new();
+            let mut min_similarity = None;
             for batch_result in batch_results {
                 if batch_result.is_ok() {
                     for (id, similarity) in batch_result.unwrap() {
+                        if min_similarity.map_or_else(|| false, |min_sim| min_sim > similarity) {
+                            continue;
+                        }
+
                         let ix = match results
                             .binary_search_by_key(&Reverse(similarity), |(_, s)| Reverse(*s))
                         {
                             Ok(ix) => ix,
                             Err(ix) => ix,
                         };
-                        results.insert(ix, (id, similarity));
-                        results.truncate(limit);
+
+                        if ix <= limit {
+                            min_similarity = Some(similarity);
+                            results.insert(ix, (id, similarity));
+                            results.truncate(limit);
+                        }
                     }
                 }
             }
@@ -856,7 +872,6 @@ impl SemanticIndex {
             })?;
 
             let buffers = futures::future::join_all(tasks).await;
-
             Ok(buffers
                 .into_iter()
                 .zip(ranges)

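In semantic_index.rs the merge of per-batch results now tracks a min_similarity floor so candidates that cannot enter the top limit are skipped before the binary-search insert, and the per-database batch size is clamped to at least 50 file ids. A sketch of that pruning merge in isolation; merge_batches is a hypothetical helper name and the bookkeeping is simplified relative to the diff:

use std::cmp::Reverse;
use ordered_float::OrderedFloat;

// Hypothetical helper: merge per-batch (id, similarity) results into one top-`limit` list,
// skipping anything that cannot beat the current worst kept score.
fn merge_batches(
    batches: Vec<Vec<(i64, OrderedFloat<f32>)>>,
    limit: usize,
) -> Vec<(i64, OrderedFloat<f32>)> {
    let mut results: Vec<(i64, OrderedFloat<f32>)> = Vec::new();
    let mut min_similarity: Option<OrderedFloat<f32>> = None;

    for batch in batches {
        for (id, similarity) in batch {
            // Once the list is full, entries below the floor can be dropped without a search.
            if results.len() >= limit && min_similarity.map_or(false, |min| similarity < min) {
                continue;
            }
            let ix = results
                .binary_search_by_key(&Reverse(similarity), |&(_, s)| Reverse(s))
                .unwrap_or_else(|ix| ix);
            results.insert(ix, (id, similarity));
            results.truncate(limit);
            min_similarity = results.last().map(|&(_, s)| s);
        }
    }
    results
}
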
script/evaluate_semantic_index

@@ -1,3 +1,3 @@
 #!/bin/bash
 
-cargo run -p semantic_index --example eval
+RUST_LOG=semantic_index=trace cargo run -p semantic_index --example eval --release