Updated database calls to share single connection, and simplified top_k_search sorting.

KCaverly and maxbrunsfeld created

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

crates/vector_store/src/db.rs           | 159 +++++++++++++-------------
crates/vector_store/src/embedding.rs    |  10 -
crates/vector_store/src/search.rs       |  47 +++----
crates/vector_store/src/vector_store.rs |  56 +++++++--
4 files changed, 148 insertions(+), 124 deletions(-)

Detailed changes

crates/vector_store/src/db.rs 🔗

@@ -1,4 +1,4 @@
-use std::collections::HashMap;
+use std::{collections::HashMap, path::PathBuf};
 
 use anyhow::{anyhow, Result};
 
@@ -46,31 +46,50 @@ impl FromSql for Embedding {
     }
 }
 
-pub struct VectorDatabase {}
+pub struct VectorDatabase {
+    db: rusqlite::Connection,
+}
 
 impl VectorDatabase {
-    pub async fn initialize_database() -> Result<()> {
+    pub fn new() -> Result<Self> {
+        let this = Self {
+            db: rusqlite::Connection::open(VECTOR_DB_URL)?,
+        };
+        this.initialize_database()?;
+        Ok(this)
+    }
+
+    fn initialize_database(&self) -> Result<()> {
         // This will create the database if it doesnt exist
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 
         // Initialize Vector Databasing Tables
-        db.execute(
+        // self.db.execute(
+        //     "
+        //     CREATE TABLE IF NOT EXISTS projects (
+        //         id INTEGER PRIMARY KEY AUTOINCREMENT,
+        //         path NVARCHAR(100) NOT NULL
+        //     )
+        //     ",
+        //     [],
+        // )?;
+
+        self.db.execute(
             "CREATE TABLE IF NOT EXISTS files (
-        id INTEGER PRIMARY KEY AUTOINCREMENT,
-        path NVARCHAR(100) NOT NULL,
-        sha1 NVARCHAR(40) NOT NULL
-        )",
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                path NVARCHAR(100) NOT NULL,
+                sha1 NVARCHAR(40) NOT NULL
+            )",
             [],
         )?;
 
-        db.execute(
+        self.db.execute(
             "CREATE TABLE IF NOT EXISTS documents (
-            id INTEGER PRIMARY KEY AUTOINCREMENT,
-            file_id INTEGER NOT NULL,
-            offset INTEGER NOT NULL,
-            name NVARCHAR(100) NOT NULL,
-            embedding BLOB NOT NULL,
-            FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                file_id INTEGER NOT NULL,
+                offset INTEGER NOT NULL,
+                name NVARCHAR(100) NOT NULL,
+                embedding BLOB NOT NULL,
+                FOREIGN KEY(file_id) REFERENCES files(id) ON DELETE CASCADE
             )",
             [],
         )?;
@@ -78,23 +97,37 @@ impl VectorDatabase {
         Ok(())
     }
 
-    pub async fn insert_file(indexed_file: IndexedFile) -> Result<()> {
-        // Write to files table, and return generated id.
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
+    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
+    //     // Check if we have the project, if we do, return the ID
+    //     // If we do not have the project, insert the project and return the ID
+
+    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
 
-        let files_insert = db.execute(
+    //     let projects_query = db.prepare(&format!(
+    //         "SELECT id FROM projects WHERE path = {}",
+    //         project_path.to_str().unwrap() // This is unsafe
+    //     ))?;
+
+    //     let project_id = db.last_insert_rowid();
+
+    //     return Ok(project_id as usize);
+    // }
+
+    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+        // Write to files table, and return generated id.
+        let files_insert = self.db.execute(
             "INSERT INTO files (path, sha1) VALUES (?1, ?2)",
             params![indexed_file.path.to_str(), indexed_file.sha1],
         )?;
 
-        let inserted_id = db.last_insert_rowid();
+        let inserted_id = self.db.last_insert_rowid();
 
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
         for document in indexed_file.documents {
             let embedding_blob = bincode::serialize(&document.embedding)?;
 
-            db.execute(
+            self.db.execute(
                 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                 params![
                     inserted_id,
@@ -109,70 +142,42 @@ impl VectorDatabase {
     }
 
     pub fn get_files(&self) -> Result<HashMap<usize, FileRecord>> {
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-        fn query(db: Connection) -> rusqlite::Result<Vec<FileRecord>> {
-            let mut query_statement = db.prepare("SELECT id, path, sha1 FROM files")?;
-            let result_iter = query_statement.query_map([], |row| {
-                Ok(FileRecord {
-                    id: row.get(0)?,
-                    path: row.get(1)?,
-                    sha1: row.get(2)?,
-                })
-            })?;
-
-            let mut results = vec![];
-            for result in result_iter {
-                results.push(result?);
-            }
-
-            return Ok(results);
-        }
+        let mut query_statement = self.db.prepare("SELECT id, path, sha1 FROM files")?;
+        let result_iter = query_statement.query_map([], |row| {
+            Ok(FileRecord {
+                id: row.get(0)?,
+                path: row.get(1)?,
+                sha1: row.get(2)?,
+            })
+        })?;
 
         let mut pages: HashMap<usize, FileRecord> = HashMap::new();
-        let result_iter = query(db);
-        if result_iter.is_ok() {
-            for result in result_iter.unwrap() {
-                pages.insert(result.id, result);
-            }
+        for result in result_iter {
+            let result = result?;
+            pages.insert(result.id, result);
         }
 
-        return Ok(pages);
+        Ok(pages)
     }
 
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
-        // Should return a HashMap in which the key is the id, and the value is the finished document
-
-        // Get Data from Database
-        let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-        fn query(db: Connection) -> rusqlite::Result<Vec<DocumentRecord>> {
-            let mut query_statement =
-                db.prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
-            let result_iter = query_statement.query_map([], |row| {
-                Ok(DocumentRecord {
-                    id: row.get(0)?,
-                    file_id: row.get(1)?,
-                    offset: row.get(2)?,
-                    name: row.get(3)?,
-                    embedding: row.get(4)?,
-                })
-            })?;
-
-            let mut results = vec![];
-            for result in result_iter {
-                results.push(result?);
-            }
-
-            return Ok(results);
-        }
+        let mut query_statement = self
+            .db
+            .prepare("SELECT id, file_id, offset, name, embedding FROM documents")?;
+        let result_iter = query_statement.query_map([], |row| {
+            Ok(DocumentRecord {
+                id: row.get(0)?,
+                file_id: row.get(1)?,
+                offset: row.get(2)?,
+                name: row.get(3)?,
+                embedding: row.get(4)?,
+            })
+        })?;
 
         let mut documents: HashMap<usize, DocumentRecord> = HashMap::new();
-        let result_iter = query(db);
-        if result_iter.is_ok() {
-            for result in result_iter.unwrap() {
-                documents.insert(result.id, result);
-            }
+        for result in result_iter {
+            let result = result?;
+            documents.insert(result.id, result);
         }
 
         return Ok(documents);

crates/vector_store/src/embedding.rs 🔗

@@ -94,16 +94,6 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             response.usage.total_tokens
         );
 
-        // do we need to re-order these based on the `index` field?
-        eprintln!(
-            "indices: {:?}",
-            response
-                .data
-                .iter()
-                .map(|embedding| embedding.index)
-                .collect::<Vec<_>>()
-        );
-
         Ok(response
             .data
             .into_iter()

crates/vector_store/src/search.rs 🔗

@@ -19,8 +19,8 @@ pub struct BruteForceSearch {
 }
 
 impl BruteForceSearch {
-    pub fn load() -> Result<Self> {
-        let db = VectorDatabase {};
+    pub fn load(db: &VectorDatabase) -> Result<Self> {
+        // let db = VectorDatabase {};
         let documents = db.get_documents()?;
         let embeddings: Vec<&DocumentRecord> = documents.values().into_iter().collect();
         let mut document_ids = vec![];
@@ -47,39 +47,36 @@ impl VectorSearch for BruteForceSearch {
     async fn top_k_search(&mut self, vec: &Vec<f32>, limit: usize) -> Vec<(usize, f32)> {
         let target = Array1::from_vec(vec.to_owned());
 
-        let distances = self.candidate_array.dot(&target);
+        let similarities = self.candidate_array.dot(&target);
 
-        let distances = distances.to_vec();
+        let similarities = similarities.to_vec();
 
         // construct a tuple vector from the floats, the tuple being (index,float)
-        let mut with_indices = distances
-            .clone()
-            .into_iter()
+        let mut with_indices = similarities
+            .iter()
+            .copied()
             .enumerate()
-            .map(|(index, value)| (index, value))
+            .map(|(index, value)| (self.document_ids[index], value))
             .collect::<Vec<(usize, f32)>>();
 
         // sort the tuple vector by float
-        with_indices.sort_by(|&a, &b| match (a.1.is_nan(), b.1.is_nan()) {
-            (true, true) => Ordering::Equal,
-            (true, false) => Ordering::Greater,
-            (false, true) => Ordering::Less,
-            (false, false) => a.1.partial_cmp(&b.1).unwrap(),
-        });
+        with_indices.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
+        with_indices.truncate(limit);
+        with_indices
 
-        // extract the sorted indices from the sorted tuple vector
-        let stored_indices = with_indices
-            .into_iter()
-            .map(|(index, value)| index)
-            .collect::<Vec<usize>>();
+        // // extract the sorted indices from the sorted tuple vector
+        // let stored_indices = with_indices
+        //     .into_iter()
+        //     .map(|(index, value)| index)
+        //     .collect::<Vec<>>();
 
-        let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
+        // let sorted_indices: Vec<usize> = stored_indices.into_iter().rev().collect();
 
-        let mut results = vec![];
-        for idx in sorted_indices[0..limit].to_vec() {
-            results.push((self.document_ids[idx], 1.0 - distances[idx]));
-        }
+        // let mut results = vec![];
+        // for idx in sorted_indices[0..limit].to_vec() {
+        //     results.push((self.document_ids[idx], 1.0 - similarities[idx]));
+        // }
 
-        return results;
+        // return results;
     }
 }

crates/vector_store/src/vector_store.rs 🔗

@@ -1,5 +1,6 @@
 mod db;
 mod embedding;
+mod parsing;
 mod search;
 
 use anyhow::{anyhow, Result};
@@ -7,11 +8,13 @@ use db::VectorDatabase;
 use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle};
 use language::LanguageRegistry;
+use parsing::Document;
 use project::{Fs, Project};
+use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
 use std::{path::PathBuf, sync::Arc, time::Instant};
 use tree_sitter::{Parser, QueryCursor};
-use util::{http::HttpClient, ResultExt};
+use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
 
 pub fn init(
@@ -39,13 +42,6 @@ pub fn init(
     .detach();
 }
 
-#[derive(Debug)]
-pub struct Document {
-    pub offset: usize,
-    pub name: String,
-    pub embedding: Vec<f32>,
-}
-
 #[derive(Debug)]
 pub struct IndexedFile {
     path: PathBuf,
@@ -180,18 +176,54 @@ impl VectorStore {
                 .detach();
 
             cx.background()
-                .spawn(async move {
+                .spawn({
+                    let client = client.clone();
+                    async move {
                     // Initialize Database, creates database and tables if not exists
-                    VectorDatabase::initialize_database().await.log_err();
+                    let db = VectorDatabase::new()?;
                     while let Ok(indexed_file) = indexed_files_rx.recv().await {
-                        VectorDatabase::insert_file(indexed_file).await.log_err();
+                        db.insert_file(indexed_file).log_err();
+                    }
+
+                    // ALL OF THE BELOW IS FOR TESTING,
+                    // This should be removed as we find and appropriate place for evaluate our search.
+
+                    let embedding_provider = OpenAIEmbeddings{ client };
+                    let queries = vec![
+                        "compute embeddings for all of the symbols in the codebase, and write them to a database",
+                            "compute an outline view of all of the symbols in a buffer",
+                            "scan a directory on the file system and load all of its children into an in-memory snapshot",
+                    ];
+                    let embeddings = embedding_provider.embed_batch(queries.clone()).await?;
+
+                    let t2 = Instant::now();
+                    let documents = db.get_documents().unwrap();
+                    let files = db.get_files().unwrap();
+                    println!("Retrieving all documents from Database: {}", t2.elapsed().as_millis());
+
+                    let t1 = Instant::now();
+                    let mut bfs = BruteForceSearch::load(&db).unwrap();
+                    println!("Loading BFS to Memory: {:?}", t1.elapsed().as_millis());
+                    for (idx, embed) in embeddings.into_iter().enumerate() {
+                        let t0 = Instant::now();
+                        println!("\nQuery: {:?}", queries[idx]);
+                        let results = bfs.top_k_search(&embed, 5).await;
+                        println!("Search Elapsed: {}", t0.elapsed().as_millis());
+                        for (id, distance) in results {
+                            println!("");
+                            println!("   distance: {:?}", distance);
+                            println!("   document: {:?}", documents[&id].name);
+                            println!("   path:     {:?}", files[&documents[&id].file_id].path);
+                        }
+
                     }
 
                     anyhow::Ok(())
-                })
+                }}.log_err())
                 .detach();
 
             let provider = DummyEmbeddings {};
+            // let provider = OpenAIEmbeddings { client };
 
             cx.background()
                 .scoped(|scope| {