diff --git a/Cargo.lock b/Cargo.lock index 3f13c75ddaac1e2a62764c95881909cec16d1e7b..309bcfa3781ae3015192ff13154a87fefd4335f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1768,9 +1768,9 @@ dependencies = [ [[package]] name = "cxx" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93" +checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8" dependencies = [ "cc", "cxxbridge-flags", @@ -1795,15 +1795,15 @@ dependencies = [ [[package]] name = "cxxbridge-flags" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb" +checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8" [[package]] name = "cxxbridge-macro" -version = "1.0.94" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5" +checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d" dependencies = [ "proc-macro2", "quote", @@ -7913,6 +7913,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "bincode", "futures 0.3.28", "gpui", "isahc", diff --git a/crates/vector_store/Cargo.toml b/crates/vector_store/Cargo.toml index 434f34114780426e8140c69d1b1657290fb8fb1e..6446651d5dcf0f66144f5fac97ded05b164f0a89 100644 --- a/crates/vector_store/Cargo.toml +++ b/crates/vector_store/Cargo.toml @@ -17,7 +17,7 @@ util = { path = "../util" } anyhow.workspace = true futures.workspace = true smol.workspace = true -rusqlite = "0.27.0" +rusqlite = { version = "0.27.0", features=["blob"] } isahc.workspace = true log.workspace = true tree-sitter.workspace = true @@ -25,6 +25,7 @@ lazy_static.workspace = true serde.workspace = true serde_json.workspace = true async-trait.workspace = true +bincode = "1.3.3" [dev-dependencies] gpui = { path = "../gpui", features = ["test-support"] } diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index e2b23f754898f12ed79eff5b4cd76f9ca8f3ee69..54f0292d1f70a721c07f7608dee1cd7aa43bb778 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -1,13 +1,44 @@ -use anyhow::Result; -use rusqlite::params; +use std::collections::HashMap; -use crate::IndexedFile; +use anyhow::{anyhow, Result}; + +use rusqlite::{ + params, + types::{FromSql, FromSqlResult, ValueRef}, + Connection, +}; +use util::ResultExt; + +use crate::{Document, IndexedFile}; // This is saving to a local database store within the users dev zed path // Where do we want this to sit? // Assuming near where the workspace DB sits. const VECTOR_DB_URL: &str = "embeddings_db"; +// Note this is not an appropriate document +#[derive(Debug)] +pub struct DocumentRecord { + id: usize, + offset: usize, + name: String, + embedding: Embedding, +} + +#[derive(Debug)] +struct Embedding(Vec); + +impl FromSql for Embedding { + fn column_result(value: ValueRef) -> FromSqlResult { + let bytes = value.as_blob()?; + let embedding: Result, Box> = bincode::deserialize(bytes); + if embedding.is_err() { + return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err())); + } + return Ok(Embedding(embedding.unwrap())); + } +} + pub struct VectorDatabase {} impl VectorDatabase { @@ -51,37 +82,66 @@ impl VectorDatabase { let inserted_id = db.last_insert_rowid(); - // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again - // I imagine there is a better way to serialize to/from blob - fn get_binary_from_values(values: Vec) -> String { - let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect(); - bits.join(";") - } - - fn get_values_from_binary(bin: &str) -> Vec { - (0..bin.len() / 32) - .map(|i| { - let start = i * 32; - let end = start + 32; - f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap()) - }) - .collect() - } - // Currently inserting at approximately 3400 documents a second // I imagine we can speed this up with a bulk insert of some kind. for document in indexed_file.documents { + let embedding_blob = bincode::serialize(&document.embedding)?; + db.execute( "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)", params![ inserted_id, document.offset.to_string(), document.name, - get_binary_from_values(document.embedding) + embedding_blob ], )?; } Ok(()) } + + pub fn get_documents(&self) -> Result> { + // Should return a HashMap in which the key is the id, and the value is the finished document + + // Get Data from Database + let db = rusqlite::Connection::open(VECTOR_DB_URL)?; + + fn query(db: Connection) -> rusqlite::Result> { + let mut query_statement = + db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?; + let result_iter = query_statement.query_map([], |row| { + Ok(DocumentRecord { + id: row.get(0)?, + offset: row.get(1)?, + name: row.get(2)?, + embedding: row.get(3)?, + }) + })?; + + let mut results = vec![]; + for result in result_iter { + results.push(result?); + } + + return Ok(results); + } + + let mut documents: HashMap = HashMap::new(); + let result_iter = query(db); + if result_iter.is_ok() { + for result in result_iter.unwrap() { + documents.insert( + result.id, + Document { + offset: result.offset, + name: result.name, + embedding: result.embedding.0, + }, + ); + } + } + + return Ok(documents); + } } diff --git a/crates/vector_store/src/embedding.rs b/crates/vector_store/src/embedding.rs index 4883917d5a348e5a5c840f00f5c092a2fcdd948a..903c2451b3ddef0501e5213a3ab007714b402107 100644 --- a/crates/vector_store/src/embedding.rs +++ b/crates/vector_store/src/embedding.rs @@ -13,6 +13,7 @@ lazy_static! { static ref OPENAI_API_KEY: Option = env::var("OPENAI_API_KEY").ok(); } +#[derive(Clone)] pub struct OpenAIEmbeddings { pub client: Arc, } @@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings { async fn embed_batch(&self, spans: Vec<&str>) -> Result>> { // 1024 is the OpenAI Embeddings size for ada models. // the model we will likely be starting with. - let dummy_vec = vec![0.32 as f32; 1024]; + let dummy_vec = vec![0.32 as f32; 1536]; return Ok(vec![dummy_vec; spans.len()]); } } diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs new file mode 100644 index 0000000000000000000000000000000000000000..3dc72edbce9da7154ebb66d39cd2074b71817180 --- /dev/null +++ b/crates/vector_store/src/search.rs @@ -0,0 +1,5 @@ +trait VectorSearch { + // Given a query vector, and a limit to return + // Return a vector of id, distance tuples. + fn top_k_search(&self, vec: &Vec) -> Vec<(usize, f32)>; +} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index f424346d569eea90dbe4cec01bed3757c4624abb..0b6d2928cccd76e0c2add7b4308fd28da614f387 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,5 +1,6 @@ mod db; mod embedding; +mod search; use anyhow::{anyhow, Result}; use db::VectorDatabase; @@ -39,10 +40,10 @@ pub fn init( } #[derive(Debug)] -struct Document { - offset: usize, - name: String, - embedding: Vec, +pub struct Document { + pub offset: usize, + pub name: String, + pub embedding: Vec, } #[derive(Debug)] @@ -185,14 +186,13 @@ impl VectorStore { while let Ok(indexed_file) = indexed_files_rx.recv().await { VectorDatabase::insert_file(indexed_file).await.log_err(); } + + anyhow::Ok(()) }) .detach(); - // let provider = OpenAIEmbeddings { client }; let provider = DummyEmbeddings {}; - let t0 = Instant::now(); - cx.background() .scoped(|scope| { for _ in 0..cx.background().num_cpus() { @@ -218,9 +218,6 @@ impl VectorStore { } }) .await; - - let duration = t0.elapsed(); - log::info!("indexed project in {duration:?}"); }) .detach(); }