@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, path::PathBuf};
use anyhow::{anyhow, Result};
@@ -46,31 +46,50 @@ impl FromSql for Embedding {
}
}
-pub struct VectorDatabase {}
+pub struct VectorDatabase {
+ db: rusqlite::Connection,
+}
impl VectorDatabase {
- pub async fn initialize_database() -> Result<()> {
+ pub fn new() -> Result<Self> {
+ let this = Self {
+ db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+ };
+ this.initialize_database()?;
+ Ok(this)
+ }
+
+ fn initialize_database(&self) -> Result<()> {
// This will create the database if it doesnt exist
- let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
// Initialize Vector Databasing Tables
- db.execute(
+ // self.db.execute(
+ // "
+ // CREATE TABLE IF NOT EXISTS projects (
+ // id INTEGER PRIMARY KEY AUTOINCREMENT,
+ // path NVARCHAR(100) NOT NULL
+ // )
+ // ",
+ // [],
+ // )?;
+
+ self.db.execute(
"CREATE TABLE IF NOT EXISTS files (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- path NVARCHAR(100) NOT NULL,
- sha1 NVARCHAR(40) NOT NULL
- )",
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ path NVARCHAR(100) NOT NULL,
+ sha1 NVARCHAR(40) NOT NULL
+ )",
[],
)?;
- db.execute(
+ self.db.execute(
"CREATE TABLE IF NOT EXISTS documents (
- id INTEGER PRIMARY KEY AUTOINCREMENT,
- file_id INTEGER NOT NULL,
- offset INTEGER NOT NULL,
- name NVARCHAR(100) NOT NULL,
- embedding BLOB NOT NULL,
- FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ file_id INTEGER NOT NULL,
+ offset INTEGER NOT NULL,
+ name NVARCHAR(100) NOT NULL,
+ embedding BLOB NOT NULL,
+ FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
)",
[],
)?;
@@ -78,23 +97,37 @@ impl VectorDatabase {
Ok(())
}
- pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
- // Write to files table, and return generated id.
- let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+ // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
+ // // Check if we have the project, if we do, return the ID
+ // // If we do not have the project, insert the project and return the ID
+
+ // let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
- let files_insert = db.execute(
+ // let projects_query = db.prepare(&format!(
+ // "SELECT id FROM projects WHERE path = {}",
+ // project_path.to_str().unwrap() // This is unsafe
+ // ))?;
+
+ // let project_id = db.last_insert_rowid();
+
+ // return Ok(project_id as usize);
+ // }
+
+ pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+ // Write to files table, and return generated id.
+ let files_insert = self.db.execute(
"INSERT INTO files (path, sha1) VALUES (?1, ?2)",
params![indexed_file.path.to_str(), indexed_file.sha1],
)?;
- let inserted_id = db.last_insert_rowid();
+ let inserted_id = self.db.last_insert_rowid();
// Currently inserting at approximately 3400 documents a second
// I imagine we can speed this up with a bulk insert of some kind.
for document in indexed_file.documents {
let embedding_blob = bincode::serialize(&document.embedding)?;
- db.execute(
+ self.db.execute(
"INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
params![
inserted_id,
@@ -109,70 +142,42 @@ impl VectorDatabase {
}
pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
- let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
- fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
- let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
- let result_iter = query_statement.query_map([], |row| {
- Ok(FileRecord {
- id: row.get(0)?,
- path: row.get(1)?,
- sha1: row.get(2)?,
- })
- })?;
-
- let mut results = vec![];
- for result in result_iter {
- results.push(result?);
- }
-
- return Ok(results);
- }
+ let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+ let result_iter = query_statement.query_map([], |row| {
+ Ok(FileRecord {
+ id: row.get(0)?,
+ path: row.get(1)?,
+ sha1: row.get(2)?,
+ })
+ })?;
let mut pages: HashMap<usize, FileRecord> = HashMap::new();
- let result_iter = query(db);
- if result_iter.is_ok() {
- for result in result_iter.unwrap() {
- pages.insert(result.id, result);
- }
+ for result in result_iter {
+ let result = result?;
+ pages.insert(result.id, result);
}
- return Ok(pages);
+ Ok(pages)
}
pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
- // Should return a HashMap in which the key is the id, and the value is the finished document
-
- // Get Data from Database
- let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
- fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
- let mut query_statement =
- db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
- let result_iter = query_statement.query_map([], |row| {
- Ok(DocumentRecord {
- id: row.get(0)?,
- file_id: row.get(1)?,
- offset: row.get(2)?,
- name: row.get(3)?,
- embedding: row.get(4)?,
- })
- })?;
-
- let mut results = vec![];
- for result in result_iter {
- results.push(result?);
- }
-
- return Ok(results);
- }
+ let mut query_statement = self
+ .db
+ .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
+ let result_iter = query_statement.query_map([], |row| {
+ Ok(DocumentRecord {
+ id: row.get(0)?,
+ file_id: row.get(1)?,
+ offset: row.get(2)?,
+ name: row.get(3)?,
+ embedding: row.get(4)?,
+ })
+ })?;
let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
- let result_iter = query(db);
- if result_iter.is_ok() {
- for result in result_iter.unwrap() {
- documents.insert(result.id, result);
- }
+ for result in result_iter {
+ let result = result?;
+ documents.insert(result.id, result);
}
return Ok(documents);
@@ -19,8 +19,8 @@ pub struct BruteForceSearch {
}
impl BruteForceSearch {
- pub fn load() -> Result<Self> {
- let db = VectorDatabase {};
+ pub fn load(db: &VectorDatabase) -> Result<Self> {
+ // let db = VectorDatabase {};
let documents = db.get_documents()?;
let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
let mut document_ids = vec![];
@@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch {
async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
let target = Array1::from_vec(vec.to_owned());
- let distances = self.candidate_array.dot(&target);
+ let similarities = self.candidate_array.dot(&target);
- let distances = distances.to_vec();
+ let similarities = similarities.to_vec();
// construct a tuple vector from the floats, the tuple being (index,float)
- let mut with_indices = distances
- .clone()
- .into_iter()
+ let mut with_indices = similarities
+ .iter()
+ .copied()
.enumerate()
- .map(|(index, value)| (index, value))
+ .map(|(index, value)| (self.document_ids[index], value))
.collect::<Vec<(usize, f32)>>();
// sort the tuple vector by float
- with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
- (true, true) => Ordering::Equal,
- (true, false) => Ordering::Greater,
- (false, true) => Ordering::Less,
- (false, false) => a.1.partial_cmp(&b.1).unwrap(),
- });
+ with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
+ with_indices.truncate(limit);
+ with_indices
- // extract the sorted indices from the sorted tuple vector
- let stored_indices = with_indices
- .into_iter()
- .map(|(index, value)| index)
- .collect::<Vec<usize>>();
+ // // extract the sorted indices from the sorted tuple vector
+ // let stored_indices = with_indices
+ // .into_iter()
+ // .map(|(index, value)| index)
+        //     .collect::<Vec<usize>>();
- let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
+ // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
- let mut results = vec![];
- for idx in sorted_indices[0..limit].to_vec() {
- results.push((self.document_ids[idx], 1.0 - distances[idx]));
- }
+ // let mut results = vec![];
+ // for idx in sorted_indices[0..limit].to_vec() {
+ // results.push((self.document_ids[idx], 1.0 - similarities[idx]));
+ // }
- return results;
+ // return results;
}
}
@@ -1,5 +1,6 @@
mod db;
mod embedding;
+mod parsing;
mod search;
use anyhow::{anyhow, Result};
@@ -7,11 +8,13 @@ use db::VectorDatabase;
use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
use gpui::{AppContext, Entity, ModelContext, ModelHandle};
use language::LanguageRegistry;
+use parsing::Document;
use project::{Fs, Project};
+use search::{BruteForceSearch, VectorSearch};
use smol::channel;
use std::{path::PathBuf, sync::Arc, time::Instant};
use tree_sitter::{Parser, QueryCursor};
-use util::{http::HttpClient, ResultExt};
+use util::{http::HttpClient, ResultExt, TryFutureExt};
use workspace::WorkspaceCreated;
pub fn init(
@@ -39,13 +42,6 @@ pub fn init(
.detach();
}
-#[derive(Debug)]
-pub struct Document {
- pub offset: usize,
- pub name: String,
- pub embedding: Vec<f32>,
-}
-
#[derive(Debug)]
pub struct IndexedFile {
path: PathBuf,
@@ -180,18 +176,54 @@ impl VectorStore {
.detach();
cx.background()
- .spawn(async move {
+ .spawn({
+ let client = client.clone();
+ async move {
// Initialize Database, creates database and tables if not exists
- VectorDatabase::initialize_database().await.log_err();
+ let db = VectorDatabase::new()?;
while let Ok(indexed_file) = indexed_files_rx.recv().await {
- VectorDatabase::insert_file(indexed_file).await.log_err();
+ db.insert_file(indexed_file).log_err();
+ }
+
+ // ALL OF THE BELOW IS FOR TESTING,
+                    // This should be removed once we find an appropriate place to evaluate our search.
+
+ let embedding_provider = OpenAIEmbeddings{ client };
+ let queries = vec![
+ "compute embeddings for all of the symbols in the codebase, and write them to a database",
+ "compute an outline view of all of the symbols in a buffer",
+ "scan a directory on the file system and load all of its children into an in-memory snapshot",
+ ];
+ let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+
+ let t2 = Instant::now();
+ let documents = db.get_documents().unwrap();
+ let files = db.get_files().unwrap();
+ println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+ let t1 = Instant::now();
+ let mut bfs = BruteForceSearch::load(&db).unwrap();
+ println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+ for (idx, embed) in embeddings.into_iter().enumerate() {
+ let t0 = Instant::now();
+ println!("\nQuery: {:?}", queries[idx]);
+ let results = bfs.top_k_search(&embed, 5).await;
+ println!("Search Elapsed: {}", t0.elapsed().as_millis());
+ for (id, distance) in results {
+ println!("");
+ println!(" distance: {:?}", distance);
+ println!(" document: {:?}", documents[&id].name);
+ println!(" path: {:?}", files[&documents[&id].file_id].path);
+ }
+
}
anyhow::Ok(())
- })
+ }}.log_err())
.detach();
let provider = DummyEmbeddings {};
+ // let provider = OpenAIEmbeddings { client };
cx.background()
.scoped(|scope| {