added brute force search and VectorSearch trait

KCaverly created

Change summary

Cargo.lock                        | 39 +++++++++++++++
crates/vector_store/Cargo.toml    |  1 
crates/vector_store/src/search.rs | 84 ++++++++++++++++++++++++++++++++
3 files changed, 122 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3837,6 +3837,16 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
 
+[[package]]
+name = "matrixmultiply"
+version = "0.3.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77"
+dependencies = [
+ "autocfg 1.1.0",
+ "rawpointer",
+]
+
 [[package]]
 name = "maybe-owned"
 version = "0.3.4"
@@ -4121,6 +4131,19 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "net2"
 version = "0.2.38"
@@ -4228,6 +4251,15 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "num-complex"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-integer"
 version = "0.1.45"
@@ -5245,6 +5277,12 @@ dependencies = [
  "rand_core 0.5.1",
 ]
 
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
 [[package]]
 name = "rayon"
 version = "1.7.0"
@@ -7920,6 +7958,7 @@ dependencies = [
  "language",
  "lazy_static",
  "log",
+ "ndarray",
  "project",
  "rusqlite",
  "serde",

crates/vector_store/Cargo.toml 🔗

@@ -26,6 +26,7 @@ serde.workspace = true
 serde_json.workspace = true
 async-trait.workspace = true
 bincode = "1.3.3"
+ndarray = "0.15.6"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/vector_store/src/search.rs 🔗

@@ -1,5 +1,85 @@
-trait VectorSearch {
+use std::cmp::Ordering;
+
+use async_trait::async_trait;
+use ndarray::{Array1, Array2};
+
+use crate::db::{DocumentRecord, VectorDatabase};
+use anyhow::Result;
+
+/// Interface for top-k vector search: given a query vector, an implementation
+/// yields `(id, distance)` tuples.
+#[async_trait]
+pub trait VectorSearch {
     // Given a query vector, and a limit to return
     // Return a vector of id, distance tuples.
-    fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
+    // `limit` is the number of tuples requested. The method is declared `async`
+    // (enabled on the trait by `async_trait`) and takes `&mut self`.
+    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)>;
+}
+
+/// Exhaustive (brute-force) search state: every candidate embedding is held in
+/// one dense matrix and scored against the query on each search.
+pub struct BruteForceSearch {
+    /// Document id for each matrix row; entry `i` identifies row `i` of
+    /// `candidate_array`.
+    document_ids: Vec<usize>,
+    /// Candidate embeddings, one per row.
+    candidate_array: Array2<f32>,
+}
+
+impl BruteForceSearch {
+    /// Number of columns in the candidate matrix, i.e. the length every stored
+    /// embedding must have.
+    // NOTE(review): presumably the output width of the embedding model in use — confirm.
+    const EMBEDDING_DIMENSION: usize = 1536;
+
+    /// Reads every document record from the vector database and packs the
+    /// embeddings into a dense `documents x EMBEDDING_DIMENSION` matrix.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error when the documents cannot be fetched, or when a stored
+    /// embedding does not have exactly `EMBEDDING_DIMENSION` values. A shorter
+    /// embedding used to panic on indexing and a longer one was silently
+    /// truncated.
+    pub fn load() -> Result<Self> {
+        let db = VectorDatabase {};
+        let documents = db.get_documents()?;
+
+        let mut document_ids = Vec::with_capacity(documents.len());
+        let mut candidate_array =
+            Array2::<f32>::default((documents.len(), Self::EMBEDDING_DIMENSION));
+
+        // Visit ids and records in a single pass so that row `i` of the matrix
+        // always belongs to `document_ids[i]`.
+        for ((id, record), mut row) in documents
+            .iter()
+            .zip(candidate_array.axis_iter_mut(ndarray::Axis(0)))
+        {
+            let embedding = &record.embedding.0;
+            anyhow::ensure!(
+                embedding.len() == Self::EMBEDDING_DIMENSION,
+                "document {} has an embedding of length {}, expected {}",
+                id,
+                embedding.len(),
+                Self::EMBEDDING_DIMENSION
+            );
+
+            for (cell, value) in row.iter_mut().zip(embedding.iter()) {
+                *cell = *value;
+            }
+            document_ids.push(*id);
+        }
+
+        Ok(BruteForceSearch {
+            document_ids,
+            candidate_array,
+        })
+    }
+}
+
+#[async_trait]
+impl VectorSearch for BruteForceSearch {
+    /// Scores every candidate row against `vec` with a dot product and returns
+    /// at most `limit` `(document_id, 1.0 - score)` tuples, highest score first.
+    ///
+    /// Fewer than `limit` tuples come back when fewer documents are loaded
+    /// (slicing `0..limit` used to panic in that case). Rows whose score is NaN
+    /// are ranked last; the previous sort-then-reverse put them first.
+    ///
+    /// # Panics
+    ///
+    /// `ndarray`'s `dot` panics when `vec.len()` differs from the number of
+    /// columns in the candidate matrix.
+    async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
+        // Borrow the query as a 1-D view instead of copying it into a new array.
+        let target = ndarray::aview1(vec.as_slice());
+        let similarities = self.candidate_array.dot(&target);
+
+        // Pair each score with the row it came from.
+        let mut ranked: Vec<(usize, f32)> = similarities.iter().copied().enumerate().collect();
+
+        // Highest score first; NaN compares as "worst" so it sinks to the end.
+        ranked.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
+            (true, true) => Ordering::Equal,
+            (true, false) => Ordering::Greater,
+            (false, true) => Ordering::Less,
+            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
+        });
+
+        // NOTE(review): `1.0 - score` is a cosine distance only if the stored
+        // and query embeddings are unit-length — confirm.
+        ranked
+            .into_iter()
+            .take(limit)
+            .map(|(row, similarity)| (self.document_ids[row], 1.0 - similarity))
+            .collect()
+    }
 }