Added proper blob serialization for embeddings and a vector-search trait

KCaverly created

Change summary

Cargo.lock                              |  13 +-
crates/vector_store/Cargo.toml          |   3 
crates/vector_store/src/db.rs           | 102 +++++++++++++++++++++-----
crates/vector_store/src/embedding.rs    |   3 
crates/vector_store/src/search.rs       |   5 +
crates/vector_store/src/vector_store.rs |  17 +--
6 files changed, 104 insertions(+), 39 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1768,9 +1768,9 @@ dependencies = [
 
 [[package]]
 name = "cxx"
-version = "1.0.94"
+version = "1.0.97"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93"
+checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8"
 dependencies = [
  "cc",
  "cxxbridge-flags",
@@ -1795,15 +1795,15 @@ dependencies = [
 
 [[package]]
 name = "cxxbridge-flags"
-version = "1.0.94"
+version = "1.0.97"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb"
+checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8"
 
 [[package]]
 name = "cxxbridge-macro"
-version = "1.0.94"
+version = "1.0.97"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5"
+checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -7913,6 +7913,7 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "async-trait",
+ "bincode",
  "futures 0.3.28",
  "gpui",
  "isahc",

crates/vector_store/Cargo.toml 🔗

@@ -17,7 +17,7 @@ util = { path = "../util" }
 anyhow.workspace = true
 futures.workspace = true
 smol.workspace = true
-rusqlite = "0.27.0"
+rusqlite = { version = "0.27.0", features=["blob"] }
 isahc.workspace = true
 log.workspace = true
 tree-sitter.workspace = true
@@ -25,6 +25,7 @@ lazy_static.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 async-trait.workspace = true
+bincode = "1.3.3"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }

crates/vector_store/src/db.rs 🔗

@@ -1,13 +1,44 @@
-use anyhow::Result;
-use rusqlite::params;
+use std::collections::HashMap;
 
-use crate::IndexedFile;
+use anyhow::{anyhow, Result};
+
+use rusqlite::{
+    params,
+    types::{FromSql, FromSqlResult, ValueRef},
+    Connection,
+};
+use util::ResultExt;
+
+use crate::{Document, IndexedFile};
 
 // This is saving to a local database store within the users dev zed path
 // Where do we want this to sit?
 // Assuming near where the workspace DB sits.
 const VECTOR_DB_URL: &str = "embeddings_db";
 
+// NOTE: DocumentRecord is the raw `documents` table row, not the in-memory Document type
+#[derive(Debug)]
+pub struct DocumentRecord {
+    id: usize,
+    offset: usize,
+    name: String,
+    embedding: Embedding,
+}
+
+#[derive(Debug)]
+struct Embedding(Vec<f32>);
+
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+        if embedding.is_err() {
+            return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+        }
+        return Ok(Embedding(embedding.unwrap()));
+    }
+}
+
 pub struct VectorDatabase {}
 
 impl VectorDatabase {
@@ -51,37 +82,66 @@ impl VectorDatabase {
 
         let inserted_id = db.last_insert_rowid();
 
-        // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again
-        // I imagine there is a better way to serialize to/from blob
-        fn get_binary_from_values(values: Vec<f32>) -> String {
-            let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect();
-            bits.join(";")
-        }
-
-        fn get_values_from_binary(bin: &str) -> Vec<f32> {
-            (0..bin.len() / 32)
-                .map(|i| {
-                    let start = i * 32;
-                    let end = start + 32;
-                    f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap())
-                })
-                .collect()
-        }
-
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
         for document in indexed_file.documents {
+            let embedding_blob = bincode::serialize(&document.embedding)?;
+
             db.execute(
                 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                 params![
                     inserted_id,
                     document.offset.to_string(),
                     document.name,
-                    get_binary_from_values(document.embedding)
+                    embedding_blob
                 ],
             )?;
         }
 
         Ok(())
     }
+
+    pub fn get_documents(&self) -> Result<HashMap<usize, Document>> {
+        // Should return a HashMap in which the key is the id, and the value is the finished document
+
+        // Get Data from Database
+        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+
+        fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
+            let mut query_statement =
+                db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?;
+            let result_iter = query_statement.query_map([], |row| {
+                Ok(DocumentRecord {
+                    id: row.get(0)?,
+                    offset: row.get(1)?,
+                    name: row.get(2)?,
+                    embedding: row.get(3)?,
+                })
+            })?;
+
+            let mut results = vec![];
+            for result in result_iter {
+                results.push(result?);
+            }
+
+            return Ok(results);
+        }
+
+        let mut documents: HashMap<usize, Document> = HashMap::new();
+        let result_iter = query(db);
+        if result_iter.is_ok() {
+            for result in result_iter.unwrap() {
+                documents.insert(
+                    result.id,
+                    Document {
+                        offset: result.offset,
+                        name: result.name,
+                        embedding: result.embedding.0,
+                    },
+                );
+            }
+        }
+
+        return Ok(documents);
+    }
 }

crates/vector_store/src/embedding.rs 🔗

@@ -13,6 +13,7 @@ lazy_static! {
     static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
 }
 
+#[derive(Clone)]
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
 }
@@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
+        // 1536 is the OpenAI embedding size for the ada-002 model,
+        // the model we will likely be starting with.
-        let dummy_vec = vec![0.32 as f32; 1024];
+        let dummy_vec = vec![0.32 as f32; 1536];
         return Ok(vec![dummy_vec; spans.len()]);
     }
 }

crates/vector_store/src/search.rs 🔗

@@ -0,0 +1,5 @@
+trait VectorSearch {
+    // Given a query vector, and a limit to return
+    // Return a vector of id, distance tuples.
+    fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
+}

crates/vector_store/src/vector_store.rs 🔗

@@ -1,5 +1,6 @@
 mod db;
 mod embedding;
+mod search;
 
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
@@ -39,10 +40,10 @@ pub fn init(
 }
 
 #[derive(Debug)]
-struct Document {
-    offset: usize,
-    name: String,
-    embedding: Vec<f32>,
+pub struct Document {
+    pub offset: usize,
+    pub name: String,
+    pub embedding: Vec<f32>,
 }
 
 #[derive(Debug)]
@@ -185,14 +186,13 @@ impl VectorStore {
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
                         VectorDatabase::insert_file(indexed_file).await.log_err();
                     }
+
+                    anyhow::Ok(())
                 })
                 .detach();
 
-            // let provider = OpenAIEmbeddings { client };
             let provider = DummyEmbeddings {};
 
-            let t0 = Instant::now();
-
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
@@ -218,9 +218,6 @@ impl VectorStore {
                     }
                 })
                 .await;
-
-            let duration = t0.elapsed();
-            log::info!("indexed project in {duration:?}");
         })
         .detach();
     }