From 85e71415fea6102001c324b08c8558abea9b07f7 Mon Sep 17 00:00:00 2001 From: KCaverly Date: Wed, 28 Jun 2023 16:25:05 -0400 Subject: [PATCH] updated embedding database calls to maintain project consistency Co-authored-by: maxbrunsfeld --- crates/vector_store/src/db.rs | 44 ----------------- crates/vector_store/src/search.rs | 66 ------------------------- crates/vector_store/src/vector_store.rs | 1 - 3 files changed, 111 deletions(-) delete mode 100644 crates/vector_store/src/search.rs diff --git a/crates/vector_store/src/db.rs b/crates/vector_store/src/db.rs index 96856936fc6a9c3e3b3dd60b2a7a171642b27581..f1453141bb981282c9b00c49c80844b9003c7a8d 100644 --- a/crates/vector_store/src/db.rs +++ b/crates/vector_store/src/db.rs @@ -236,27 +236,6 @@ impl VectorDatabase { Ok(result) } - pub fn get_files(&self) -> Result> { - let mut query_statement = self - .db - .prepare("SELECT id, relative_path, sha1 FROM files")?; - let result_iter = query_statement.query_map([], |row| { - Ok(FileRecord { - id: row.get(0)?, - relative_path: row.get(1)?, - sha1: row.get(2)?, - }) - })?; - - let mut pages: HashMap = HashMap::new(); - for result in result_iter { - let result = result?; - pages.insert(result.id, result); - } - - Ok(pages) - } - pub fn for_each_document( &self, worktree_ids: &[i64], @@ -321,29 +300,6 @@ impl VectorDatabase { Ok(results) } - - pub fn get_documents(&self) -> Result> { - let mut query_statement = self - .db - .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?; - let result_iter = query_statement.query_map([], |row| { - Ok(DocumentRecord { - id: row.get(0)?, - file_id: row.get(1)?, - offset: row.get(2)?, - name: row.get(3)?, - embedding: row.get(4)?, - }) - })?; - - let mut documents: HashMap = HashMap::new(); - for result in result_iter { - let result = result?; - documents.insert(result.id, result); - } - - return Ok(documents); - } } fn ids_to_sql(ids: &[i64]) -> Rc> { diff --git a/crates/vector_store/src/search.rs b/crates/vector_store/src/search.rs deleted file mode 100644 index 90a8d874da00cda9dc4c90bbe161d6eb454e1246..0000000000000000000000000000000000000000 --- a/crates/vector_store/src/search.rs +++ /dev/null @@ -1,66 +0,0 @@ -use std::{cmp::Ordering, path::PathBuf}; - -use async_trait::async_trait; -use ndarray::{Array1, Array2}; - -use crate::db::{DocumentRecord, VectorDatabase}; -use anyhow::Result; - -#[async_trait] -pub trait VectorSearch { - // Given a query vector, and a limit to return - // Return a vector of id, distance tuples. - async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)>; -} - -pub struct BruteForceSearch { - document_ids: Vec, - candidate_array: ndarray::Array2, -} - -impl BruteForceSearch { - pub fn load(db: &VectorDatabase) -> Result { - let documents = db.get_documents()?; - let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect(); - let mut document_ids = vec![]; - for i in documents.keys() { - document_ids.push(i.to_owned()); - } - - let mut candidate_array = Array2::::default((documents.len(), 1536)); - for (i, mut row) in candidate_array.axis_iter_mut(ndarray::Axis(0)).enumerate() { - for (j, col) in row.iter_mut().enumerate() { - *col = embeddings[i].embedding.0[j]; - } - } - - return Ok(BruteForceSearch { - document_ids, - candidate_array, - }); - } -} - -#[async_trait] -impl VectorSearch for BruteForceSearch { - async fn top_k_search(&mut self, vec: &Vec, limit: usize) -> Vec<(usize, f32)> { - let target = Array1::from_vec(vec.to_owned()); - - let similarities = self.candidate_array.dot(&target); - - let similarities = similarities.to_vec(); - - // construct a tuple vector from the floats, the tuple being (index,float) - let mut with_indices = similarities - .iter() - .copied() - .enumerate() - .map(|(index, value)| (self.document_ids[index], value)) - .collect::>(); - - // sort the tuple vector by float - with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)); - with_indices.truncate(limit); - with_indices - } -} diff --git a/crates/vector_store/src/vector_store.rs b/crates/vector_store/src/vector_store.rs index 92926b1f752adb799eaef7ae4f63dec92df553ce..a66c2d65ba63bbad387b1a63b9d0057244578787 100644 --- a/crates/vector_store/src/vector_store.rs +++ b/crates/vector_store/src/vector_store.rs @@ -1,7 +1,6 @@ mod db; mod embedding; mod modal; -mod search; #[cfg(test)] mod vector_store_tests;