Detailed changes
@@ -1768,9 +1768,9 @@ dependencies = [
[[package]]
name = "cxx"
-version = "1.0.94"
+version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f61f1b6389c3fe1c316bf8a4dccc90a38208354b330925bce1f74a6c4756eb93"
+checksum = "e88abab2f5abbe4c56e8f1fb431b784d710b709888f35755a160e62e33fe38e8"
dependencies = [
"cc",
"cxxbridge-flags",
@@ -1795,15 +1795,15 @@ dependencies = [
[[package]]
name = "cxxbridge-flags"
-version = "1.0.94"
+version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7944172ae7e4068c533afbb984114a56c46e9ccddda550499caa222902c7f7bb"
+checksum = "8d3816ed957c008ccd4728485511e3d9aaf7db419aa321e3d2c5a2f3411e36c8"
[[package]]
name = "cxxbridge-macro"
-version = "1.0.94"
+version = "1.0.97"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2345488264226bf682893e25de0769f3360aac9957980ec49361b083ddaa5bc5"
+checksum = "a26acccf6f445af85ea056362561a24ef56cdc15fcc685f03aec50b9c702cb6d"
dependencies = [
"proc-macro2",
"quote",
@@ -7913,6 +7913,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
+ "bincode",
"futures 0.3.28",
"gpui",
"isahc",
@@ -17,7 +17,7 @@ util = { path = "../util" }
anyhow.workspace = true
futures.workspace = true
smol.workspace = true
-rusqlite = "0.27.0"
+rusqlite = { version = "0.27.0", features = ["blob"] }
isahc.workspace = true
log.workspace = true
tree-sitter.workspace = true
@@ -25,6 +25,7 @@ lazy_static.workspace = true
serde.workspace = true
serde_json.workspace = true
async-trait.workspace = true
+bincode = "1.3.3"
[dev-dependencies]
gpui = { path = "../gpui", features = ["test-support"] }
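
The bincode dependency added here is what turns an embedding vector into the bytes the store keeps in the new BLOB column below. A minimal round-trip sketch of that encoding (illustrative only, not part of the change):

    // Encode a Vec<f32> the way the documents table's BLOB column does,
    // then decode it back. bincode::serialize's default layout is
    // little-endian fixint: an 8-byte length prefix plus 4 bytes per f32.
    fn main() -> Result<(), Box<bincode::ErrorKind>> {
        let embedding: Vec<f32> = vec![0.32; 1536];
        let blob: Vec<u8> = bincode::serialize(&embedding)?;
        assert_eq!(blob.len(), 8 + 4 * embedding.len());
        let decoded: Vec<f32> = bincode::deserialize(&blob)?;
        assert_eq!(embedding, decoded);
        Ok(())
    }
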
@@ -1,13 +1,44 @@
-use anyhow::Result;
-use rusqlite::params;
+use std::collections::HashMap;
-use crate::IndexedFile;
+use anyhow::{anyhow, Result};
+
+use rusqlite::{
+ params,
+ types::{FromSql, FromSqlResult, ValueRef},
+ Connection,
+};
+use util::ResultExt;
+
+use crate::{Document, IndexedFile};
// This is saving to a local database store within the user's dev zed path
// Where do we want this to sit?
// Assuming near where the workspace DB sits.
const VECTOR_DB_URL: &str = "embeddings_db";
+// Note: this is a raw database row, not the in-memory Document type.
+#[derive(Debug)]
+pub struct DocumentRecord {
+ id: usize,
+ offset: usize,
+ name: String,
+ embedding: Embedding,
+}
+
+#[derive(Debug)]
+struct Embedding(Vec<f32>);
+
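+// Decode the bincode-encoded BLOB column back into a Vec<f32>.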
+impl FromSql for Embedding {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ bincode::deserialize(bytes)
+ .map(Embedding)
+ .map_err(|err| rusqlite::types::FromSqlError::Other(err))
+ }
+}
+
pub struct VectorDatabase {}
impl VectorDatabase {
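
Only the read half of the conversion is implemented above; a matching write half (a sketch, not part of this change) could let an Embedding be bound directly as a statement parameter:

    use rusqlite::types::{ToSql, ToSqlOutput, Value};

    // Hypothetical counterpart to the FromSql impl above: encode the
    // embedding with bincode and hand rusqlite an owned BLOB value.
    impl ToSql for Embedding {
        fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
            let blob = bincode::serialize(&self.0)
                .map_err(|err| rusqlite::Error::ToSqlConversionFailure(err))?;
            Ok(ToSqlOutput::Owned(Value::Blob(blob)))
        }
    }
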
@@ -51,37 +82,66 @@ impl VectorDatabase {
let inserted_id = db.last_insert_rowid();
- // I stole this from https://stackoverflow.com/questions/71829931/how-do-i-convert-a-negative-f32-value-to-binary-string-and-back-again
- // I imagine there is a better way to serialize to/from blob
- fn get_binary_from_values(values: Vec<f32>) -> String {
- let bits: Vec<_> = values.iter().map(|v| v.to_bits().to_string()).collect();
- bits.join(";")
- }
-
- fn get_values_from_binary(bin: &str) -> Vec<f32> {
- (0..bin.len() / 32)
- .map(|i| {
- let start = i * 32;
- let end = start + 32;
- f32::from_bits(u32::from_str_radix(&bin[start..end], 2).unwrap())
- })
- .collect()
- }
-
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents {
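+ // Serialize the embedding to a compact byte blob; bincode replaces
+ // the hand-rolled bit-string encoding removed above.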
+ let embedding_blob = bincode::serialize(&document.embedding)?;
+
db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
inserted_id,
document.offset.to_string(),
document.name,
- get_binary_from_values(document.embedding)
+ embedding_blob
],
)?;
}
Ok(())
}
+
+ pub fn get_documents(&self) -> Result<HashMap<usize, Document>> {
+ // Should return a HashMap in which the key is the id, and the value is the finished document
+
+ // Get Data from Database
+ let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+
+ fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
+ let mut query_statement =
+ db.prepare("SELECT id, offset, name, embedding FROM documents LIMIT 10")?;
+ let result_iter = query_statement.query_map([], |row| {
+ Ok(DocumentRecord {
+ id: row.get(0)?,
+ offset: row.get(1)?,
+ name: row.get(2)?,
+ embedding: row.get(3)?,
+ })
+ })?;
+
+ result_iter.collect()
+ }
+
+ let mut documents: HashMap<usize, Document> = HashMap::new();
+ for result in query(db)? {
+ documents.insert(
+ result.id,
+ Document {
+ offset: result.offset,
+ name: result.name,
+ embedding: result.embedding.0,
+ },
+ );
+ }
+
+ Ok(documents)
+ }
}
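
The comment in the insert loop above measures roughly 3,400 documents per second and suggests a bulk insert; one way to sketch that (illustrative helper, not part of the diff) is a single transaction with a cached prepared statement:

    use anyhow::Result;
    use rusqlite::{params, Connection};

    // Hypothetical bulk insert: one transaction and one cached prepared
    // statement instead of a standalone INSERT per document.
    fn insert_documents_bulk(db: &mut Connection, file_id: i64, documents: &[Document]) -> Result<()> {
        let tx = db.transaction()?;
        {
            let mut statement = tx.prepare_cached(
                "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
            )?;
            for document in documents {
                let embedding_blob = bincode::serialize(&document.embedding)?;
                statement.execute(params![
                    file_id,
                    document.offset.to_string(),
                    document.name,
                    embedding_blob
                ])?;
            }
        }
        tx.commit()?;
        Ok(())
    }
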
@@ -13,6 +13,7 @@ lazy_static! {
static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
}
+#[derive(Clone)]
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
}
@@ -54,7 +55,7 @@ impl EmbeddingProvider for DummyEmbeddings {
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
// 1536 is the OpenAI embedding size for ada models.
// the model we will likely be starting with.
- let dummy_vec = vec![0.32 as f32; 1024];
+ let dummy_vec = vec![0.32 as f32; 1536];
return Ok(vec![dummy_vec; spans.len()]);
}
}
@@ -0,0 +1,5 @@
+trait VectorSearch {
+ // Given a query vector and a limit, return a vector of
+ // (id, distance) tuples for the closest stored embeddings.
+ fn top_k_search(&self, vec: &Vec<f32>) -> Vec<(usize, f32)>;
+}
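
The trait is only a stub in this change; a brute-force implementation over the stored embeddings might look roughly like the sketch below (cosine distance and the explicit k parameter are assumptions, not part of the diff):

    // Hypothetical brute-force search over (id, embedding) pairs.
    // Returns the k closest ids by cosine distance (smaller is closer).
    fn top_k_by_cosine(query: &[f32], docs: &[(usize, Vec<f32>)], k: usize) -> Vec<(usize, f32)> {
        fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
            let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
            let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
            let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
            1.0 - dot / (norm_a * norm_b)
        }

        let mut scored: Vec<(usize, f32)> = docs
            .iter()
            .map(|(id, embedding)| (*id, cosine_distance(query, embedding)))
            .collect();
        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.truncate(k);
        scored
    }

A linear scan like this is fine for a handful of files; past that, an approximate nearest-neighbor index would presumably replace it.
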
@@ -1,5 +1,6 @@
mod db;
mod embedding;
+mod search;
use anyhow::{anyhow, Result};
use db::VectorDatabase;
@@ -39,10 +40,10 @@ pub fn init(
}
#[derive(Debug)]
-struct Document {
- offset: usize,
- name: String,
- embedding: Vec<f32>,
+pub struct Document {
+ pub offset: usize,
+ pub name: String,
+ pub embedding: Vec<f32>,
}
#[derive(Debug)]
@@ -185,14 +186,13 @@ impl VectorStore {
while let Ok(indexed_file) = indexed_files_rx.recv().await {
VectorDatabase::insert_file(indexed_file).await.log_err();
}
+
+ anyhow::Ok(())
})
.detach();
- // let provider = OpenAIEmbeddings { client };
let provider = DummyEmbeddings {};
- let t0 = Instant::now();
-
cx.background()
.scoped(|scope| {
for _ in 0..cx.background().num_cpus() {
@@ -218,9 +218,6 @@ impl VectorStore {
}
})
.await;
-
- let duration = t0.elapsed();
- log::info!("indexed project in {duration:?}");
})
.detach();
}