Working incremental index engine, with streaming similarity search!

Created by KCaverly and maxbrunsfeld

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

Cargo.lock                                    |   1 
crates/vector_store/Cargo.toml                |   3 
crates/vector_store/src/db.rs                 | 184 ++++++++++++++++----
crates/vector_store/src/vector_store.rs       | 170 +++++++++++-------
crates/vector_store/src/vector_store_tests.rs |  23 +
5 files changed, 269 insertions(+), 112 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7967,6 +7967,7 @@ dependencies = [
  "serde_json",
  "sha-1 0.10.1",
  "smol",
+ "tempdir",
  "tree-sitter",
  "tree-sitter-rust",
  "unindent",

crates/vector_store/Cargo.toml 🔗

@@ -17,7 +17,7 @@ util = { path = "../util" }
 anyhow.workspace = true
 futures.workspace = true
 smol.workspace = true
-rusqlite = { version = "0.27.0", features=["blob"] }
+rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }
 isahc.workspace = true
 log.workspace = true
 tree-sitter.workspace = true
@@ -38,3 +38,4 @@ workspace = { path = "../workspace", features = ["test-support"] }
 tree-sitter-rust = "*"
 rand.workspace = true
 unindent.workspace = true
+tempdir.workspace = true

crates/vector_store/src/db.rs 🔗

@@ -7,9 +7,10 @@ use anyhow::{anyhow, Result};
 
 use rusqlite::{
     params,
-    types::{FromSql, FromSqlResult, ValueRef},
-    Connection,
+    types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef},
+    ToSql,
 };
+use sha1::{Digest, Sha1};
 
 use crate::IndexedFile;
 
@@ -32,7 +33,60 @@ pub struct DocumentRecord {
 pub struct FileRecord {
     pub id: usize,
     pub relative_path: String,
-    pub sha1: String,
+    pub sha1: FileSha1,
+}
+
+#[derive(Debug)]
+pub struct FileSha1(pub Vec<u8>);
+
+impl FileSha1 {
+    pub fn from_str(content: String) -> Self {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+        return FileSha1(sha1);
+    }
+
+    pub fn equals(&self, content: &String) -> bool {
+        let mut hasher = Sha1::new();
+        hasher.update(content);
+        let sha1 = hasher.finalize()[..]
+            .into_iter()
+            .map(|val| val.to_owned())
+            .collect::<Vec<u8>>();
+
+        let equal = self
+            .0
+            .clone()
+            .into_iter()
+            .zip(sha1)
+            .filter(|&(a, b)| a == b)
+            .count()
+            == self.0.len();
+
+        equal
+    }
+}
+
+impl ToSql for FileSha1 {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
+        return self.0.to_sql();
+    }
+}
+
+impl FromSql for FileSha1 {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        Ok(FileSha1(
+            bytes
+                .into_iter()
+                .map(|val| val.to_owned())
+                .collect::<Vec<u8>>(),
+        ))
+    }
 }
 
 #[derive(Debug)]
@@ -63,6 +117,8 @@ impl VectorDatabase {
     }
 
     fn initialize_database(&self) -> Result<()> {
+        rusqlite::vtab::array::load_module(&self.db)?;
+
         // This will create the database if it doesnt exist
 
         // Initialize Vector Databasing Tables
@@ -81,7 +137,7 @@ impl VectorDatabase {
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 worktree_id INTEGER NOT NULL,
                 relative_path VARCHAR NOT NULL,
-                sha1 NVARCHAR(40) NOT NULL,
+                sha1 BLOB NOT NULL,
                 FOREIGN KEY(worktree_id) REFERENCES worktrees(id) ON DELETE CASCADE
             )",
             [],
@@ -102,30 +158,23 @@ impl VectorDatabase {
         Ok(())
     }
 
-    // pub async fn get_or_create_project(project_path: PathBuf) -> Result<usize> {
-    //     // Check if we have the project, if we do, return the ID
-    //     // If we do not have the project, insert the project and return the ID
-
-    //     let db = rusqlite::Connection::open(VECTOR_DB_URL)?;
-
-    //     let projects_query = db.prepare(&format!(
-    //         "SELECT id FROM projects WHERE path = {}",
-    //         project_path.to_str().unwrap() // This is unsafe
-    //     ))?;
-
-    //     let project_id = db.last_insert_rowid();
-
-    //     return Ok(project_id as usize);
-    // }
-
-    pub fn insert_file(&self, indexed_file: IndexedFile) -> Result<()> {
+    pub fn insert_file(&self, worktree_id: i64, indexed_file: IndexedFile) -> Result<()> {
         // Write to files table, and return generated id.
-        let files_insert = self.db.execute(
-            "INSERT INTO files (relative_path, sha1) VALUES (?1, ?2)",
-            params![indexed_file.path.to_str(), indexed_file.sha1],
+        log::info!("Inserting File!");
+        self.db.execute(
+            "
+            DELETE FROM files WHERE worktree_id = ?1 AND relative_path = ?2;
+            ",
+            params![worktree_id, indexed_file.path.to_str()],
+        )?;
+        self.db.execute(
+            "
+            INSERT INTO files (worktree_id, relative_path, sha1) VALUES (?1, ?2, $3);
+            ",
+            params![worktree_id, indexed_file.path.to_str(), indexed_file.sha1],
         )?;
 
-        let inserted_id = self.db.last_insert_rowid();
+        let file_id = self.db.last_insert_rowid();
 
         // Currently inserting at approximately 3400 documents a second
         // I imagine we can speed this up with a bulk insert of some kind.
@@ -135,7 +184,7 @@ impl VectorDatabase {
             self.db.execute(
                 "INSERT INTO documents (file_id, offset, name, embedding) VALUES (?1, ?2, ?3, ?4)",
                 params![
-                    inserted_id,
+                    file_id,
                     document.offset.to_string(),
                     document.name,
                     embedding_blob
@@ -147,25 +196,41 @@ impl VectorDatabase {
     }
 
     pub fn find_or_create_worktree(&self, worktree_root_path: &Path) -> Result<i64> {
+        // Check that the absolute path doesnt exist
+        let mut worktree_query = self
+            .db
+            .prepare("SELECT id FROM worktrees WHERE absolute_path = ?1")?;
+
+        let worktree_id = worktree_query
+            .query_row(params![worktree_root_path.to_string_lossy()], |row| {
+                Ok(row.get::<_, i64>(0)?)
+            })
+            .map_err(|err| anyhow!(err));
+
+        if worktree_id.is_ok() {
+            return worktree_id;
+        }
+
+        // If worktree_id is Err, insert new worktree
         self.db.execute(
             "
             INSERT into worktrees (absolute_path) VALUES (?1)
-            ON CONFLICT DO NOTHING
             ",
             params![worktree_root_path.to_string_lossy()],
         )?;
         Ok(self.db.last_insert_rowid())
     }
 
-    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<Vec<(PathBuf, String)>> {
-        let mut statement = self
-            .db
-            .prepare("SELECT relative_path, sha1 FROM files ORDER BY relative_path")?;
-        let mut result = Vec::new();
-        for row in
-            statement.query_map([], |row| Ok((row.get::<_, String>(0)?.into(), row.get(1)?)))?
-        {
-            result.push(row?);
+    pub fn get_file_hashes(&self, worktree_id: i64) -> Result<HashMap<PathBuf, FileSha1>> {
+        let mut statement = self.db.prepare(
+            "SELECT relative_path, sha1 FROM files WHERE worktree_id = ?1 ORDER BY relative_path",
+        )?;
+        let mut result: HashMap<PathBuf, FileSha1> = HashMap::new();
+        for row in statement.query_map(params![worktree_id], |row| {
+            Ok((row.get::<_, String>(0)?.into(), row.get(1)?))
+        })? {
+            let row = row?;
+            result.insert(row.0, row.1);
         }
         Ok(result)
     }
@@ -204,6 +269,53 @@ impl VectorDatabase {
         Ok(())
     }
 
+    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(PathBuf, usize, String)>> {
+        let mut statement = self.db.prepare(
+            "
+                SELECT
+                    documents.id, files.relative_path, documents.offset, documents.name
+                FROM
+                    documents, files
+                WHERE
+                    documents.file_id = files.id AND
+                    documents.id in rarray(?)
+            ",
+        )?;
+
+        let result_iter = statement.query_map(
+            params![std::rc::Rc::new(
+                ids.iter()
+                    .copied()
+                    .map(|v| rusqlite::types::Value::from(v))
+                    .collect::<Vec<_>>()
+            )],
+            |row| {
+                Ok((
+                    row.get::<_, i64>(0)?,
+                    row.get::<_, String>(1)?.into(),
+                    row.get(2)?,
+                    row.get(3)?,
+                ))
+            },
+        )?;
+
+        let mut values_by_id = HashMap::<i64, (PathBuf, usize, String)>::default();
+        for row in result_iter {
+            let (id, path, offset, name) = row?;
+            values_by_id.insert(id, (path, offset, name));
+        }
+
+        let mut results = Vec::with_capacity(ids.len());
+        for id in ids {
+            let (path, offset, name) = values_by_id
+                .remove(id)
+                .ok_or(anyhow!("missing document id {}", id))?;
+            results.push((path, offset, name));
+        }
+
+        Ok(results)
+    }
+
     pub fn get_documents(&self) -> Result<HashMap<usize, DocumentRecord>> {
         let mut query_statement = self
             .db

crates/vector_store/src/vector_store.rs 🔗

@@ -7,15 +7,14 @@ mod search;
 mod vector_store_tests;
 
 use anyhow::{anyhow, Result};
-use db::{VectorDatabase, VECTOR_DB_URL};
-use embedding::{DummyEmbeddings, EmbeddingProvider, OpenAIEmbeddings};
+use db::{FileSha1, VectorDatabase, VECTOR_DB_URL};
+use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
-use language::LanguageRegistry;
+use language::{Language, LanguageRegistry};
 use parsing::Document;
 use project::{Fs, Project};
-use search::{BruteForceSearch, VectorSearch};
 use smol::channel;
-use std::{cmp::Ordering, path::PathBuf, sync::Arc, time::Instant};
+use std::{cmp::Ordering, collections::HashMap, path::PathBuf, sync::Arc};
 use tree_sitter::{Parser, QueryCursor};
 use util::{http::HttpClient, ResultExt, TryFutureExt};
 use workspace::WorkspaceCreated;
@@ -45,7 +44,7 @@ pub fn init(
                 let project = workspace.read(cx).project().clone();
                 if project.read(cx).is_local() {
                     vector_store.update(cx, |store, cx| {
-                        store.add_project(project, cx);
+                        store.add_project(project, cx).detach();
                     });
                 }
             }
@@ -57,16 +56,10 @@ pub fn init(
 #[derive(Debug)]
 pub struct IndexedFile {
     path: PathBuf,
-    sha1: String,
+    sha1: FileSha1,
     documents: Vec<Document>,
 }
 
-// struct SearchResult {
-//     path: PathBuf,
-//     offset: usize,
-//     name: String,
-//     distance: f32,
-// }
 struct VectorStore {
     fs: Arc<dyn Fs>,
     database_url: Arc<str>,
@@ -99,20 +92,10 @@ impl VectorStore {
         cursor: &mut QueryCursor,
         parser: &mut Parser,
         embedding_provider: &dyn EmbeddingProvider,
-        language_registry: &Arc<LanguageRegistry>,
+        language: Arc<Language>,
         file_path: PathBuf,
         content: String,
     ) -> Result<IndexedFile> {
-        dbg!(&file_path, &content);
-
-        let language = language_registry
-            .language_for_file(&file_path, None)
-            .await?;
-
-        if language.name().as_ref() != "Rust" {
-            Err(anyhow!("unsupported language"))?;
-        }
-
         let grammar = language.grammar().ok_or_else(|| anyhow!("no grammar"))?;
         let outline_config = grammar
             .outline_config
@@ -156,9 +139,11 @@ impl VectorStore {
             document.embedding = embedding;
         }
 
+        let sha1 = FileSha1::from_str(content);
+
         return Ok(IndexedFile {
             path: file_path,
-            sha1: String::new(),
+            sha1,
             documents,
         });
     }
@@ -171,7 +156,13 @@ impl VectorStore {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
-            .map(|worktree| worktree.read(cx).as_local().unwrap().scan_complete())
+            .map(|worktree| {
+                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
+                async move {
+                    scan_complete.await;
+                    log::info!("worktree scan completed");
+                }
+            })
             .collect::<Vec<_>>();
 
         let fs = self.fs.clone();
@@ -182,6 +173,13 @@ impl VectorStore {
         cx.spawn(|_, cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
+            // TODO: remove this after fixing the bug in scan_complete
+            cx.background()
+                .timer(std::time::Duration::from_secs(3))
+                .await;
+
+            let db = VectorDatabase::new(&database_url)?;
+
             let worktrees = project.read_with(&cx, |project, cx| {
                 project
                     .worktrees(cx)
@@ -189,37 +187,74 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let db = VectorDatabase::new(&database_url)?;
             let worktree_root_paths = worktrees
                 .iter()
                 .map(|worktree| worktree.abs_path().clone())
                 .collect::<Vec<_>>();
-            let (db, file_hashes) = cx
+
+            // Here we query the worktree ids, and yet we dont have them elsewhere
+            // We likely want to clean up these datastructures
+            let (db, worktree_hashes, worktree_ids) = cx
                 .background()
                 .spawn(async move {
-                    let mut hashes = Vec::new();
+                    let mut worktree_ids: HashMap<PathBuf, i64> = HashMap::new();
+                    let mut hashes: HashMap<i64, HashMap<PathBuf, FileSha1>> = HashMap::new();
                     for worktree_root_path in worktree_root_paths {
                         let worktree_id =
                             db.find_or_create_worktree(worktree_root_path.as_ref())?;
-                        hashes.push((worktree_id, db.get_file_hashes(worktree_id)?));
+                        worktree_ids.insert(worktree_root_path.to_path_buf(), worktree_id);
+                        hashes.insert(worktree_id, db.get_file_hashes(worktree_id)?);
                     }
-                    anyhow::Ok((db, hashes))
+                    anyhow::Ok((db, hashes, worktree_ids))
                 })
                 .await?;
 
-            let (paths_tx, paths_rx) = channel::unbounded::<(i64, PathBuf, String)>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<IndexedFile>();
+            let (paths_tx, paths_rx) =
+                channel::unbounded::<(i64, PathBuf, String, Arc<Language>)>();
+            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
             cx.background()
                 .spawn({
                     let fs = fs.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
+                            let worktree_id = worktree_ids[&worktree.abs_path().to_path_buf()];
+                            let file_hashes = &worktree_hashes[&worktree_id];
                             for file in worktree.files(false, 0) {
                                 let absolute_path = worktree.absolutize(&file.path);
-                                dbg!(&absolute_path);
-                                if let Some(content) = fs.load(&absolute_path).await.log_err() {
-                                    dbg!(&content);
-                                    paths_tx.try_send((0, absolute_path, content)).unwrap();
+
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    if language.name().as_ref() != "Rust" {
+                                        continue;
+                                    }
+
+                                    if let Some(content) = fs.load(&absolute_path).await.log_err() {
+                                        log::info!("loaded file: {absolute_path:?}");
+
+                                        let path_buf = file.path.to_path_buf();
+                                        let already_stored = file_hashes
+                                            .get(&path_buf)
+                                            .map_or(false, |existing_hash| {
+                                                existing_hash.equals(&content)
+                                            });
+
+                                        if !already_stored {
+                                            log::info!(
+                                                "File Changed (Sending to Parse): {:?}",
+                                                &path_buf
+                                            );
+                                            paths_tx
+                                                .try_send((
+                                                    worktree_id,
+                                                    path_buf,
+                                                    content,
+                                                    language,
+                                                ))
+                                                .unwrap();
+                                        }
+                                    }
                                 }
                             }
                         }
@@ -230,8 +265,8 @@ impl VectorStore {
             let db_write_task = cx.background().spawn(
                 async move {
                     // Initialize Database, creates database and tables if not exists
-                    while let Ok(indexed_file) = indexed_files_rx.recv().await {
-                        db.insert_file(indexed_file).log_err();
+                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
+                        db.insert_file(worktree_id, indexed_file).log_err();
                     }
 
                     // ALL OF THE BELOW IS FOR TESTING,
@@ -271,29 +306,29 @@ impl VectorStore {
                 .log_err(),
             );
 
-            let provider = DummyEmbeddings {};
-            // let provider = OpenAIEmbeddings { client };
-
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
                         scope.spawn(async {
                             let mut parser = Parser::new();
                             let mut cursor = QueryCursor::new();
-                            while let Ok((worktree_id, file_path, content)) = paths_rx.recv().await
+                            while let Ok((worktree_id, file_path, content, language)) =
+                                paths_rx.recv().await
                             {
                                 if let Some(indexed_file) = Self::index_file(
                                     &mut cursor,
                                     &mut parser,
-                                    &provider,
-                                    &language_registry,
+                                    embedding_provider.as_ref(),
+                                    language,
                                     file_path,
                                     content,
                                 )
                                 .await
                                 .log_err()
                                 {
-                                    indexed_files_tx.try_send(indexed_file).unwrap();
+                                    indexed_files_tx
+                                        .try_send((worktree_id, indexed_file))
+                                        .unwrap();
                                 }
                             }
                         });
@@ -315,41 +350,42 @@ impl VectorStore {
     ) -> Task<Result<Vec<SearchResult>>> {
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
-        cx.spawn(|this, cx| async move {
+        cx.background().spawn(async move {
             let database = VectorDatabase::new(database_url.as_ref())?;
 
-            // let embedding = embedding_provider.embed_batch(vec![&phrase]).await?;
-            //
-            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+            let phrase_embedding = embedding_provider
+                .embed_batch(vec![&phrase])
+                .await?
+                .into_iter()
+                .next()
+                .unwrap();
 
+            let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
             database.for_each_document(0, |id, embedding| {
-                dbg!(id, &embedding);
-
-                let similarity = dot(&embedding.0, &embedding.0);
+                let similarity = dot(&embedding.0, &phrase_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
-                    s.partial_cmp(&similarity).unwrap_or(Ordering::Equal)
+                    similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
                 }) {
                     Ok(ix) => ix,
                     Err(ix) => ix,
                 };
-
                 results.insert(ix, (id, similarity));
                 results.truncate(limit);
             })?;
 
-            dbg!(&results);
-
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-            // let documents = database.get_documents_by_ids(ids)?;
-
-            // let search_provider = cx
-            //     .background()
-            //     .spawn(async move { BruteForceSearch::load(&database) })
-            //     .await?;
-
-            // let results = search_provider.top_k_search(&embedding, limit))
-
-            anyhow::Ok(vec![])
+            let documents = database.get_documents_by_ids(&ids)?;
+
+            anyhow::Ok(
+                documents
+                    .into_iter()
+                    .map(|(file_path, offset, name)| SearchResult {
+                        name,
+                        offset,
+                        file_path,
+                    })
+                    .collect(),
+            )
         })
     }
 }

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -57,20 +57,26 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     );
     languages.add(rust_language);
 
+    let db_dir = tempdir::TempDir::new("vector-store").unwrap();
+    let db_path = db_dir.path().join("db.sqlite");
+
     let store = cx.add_model(|_| {
         VectorStore::new(
             fs.clone(),
-            "foo".to_string(),
+            db_path.to_string_lossy().to_string(),
             Arc::new(FakeEmbeddingProvider),
             languages,
         )
     });
 
     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
-    store
-        .update(cx, |store, cx| store.add_project(project, cx))
-        .await
-        .unwrap();
+    let add_project = store.update(cx, |store, cx| store.add_project(project, cx));
+
+    // TODO - remove
+    cx.foreground()
+        .advance_clock(std::time::Duration::from_secs(3));
+
+    add_project.await.unwrap();
 
     let search_results = store
         .update(cx, |store, cx| store.search("aaaa".to_string(), 5, cx))
@@ -78,7 +84,7 @@ async fn test_vector_store(cx: &mut TestAppContext) {
         .unwrap();
 
     assert_eq!(search_results[0].offset, 0);
-    assert_eq!(search_results[1].name, "aaa");
+    assert_eq!(search_results[0].name, "aaa");
 }
 
 #[test]
@@ -114,9 +120,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         Ok(spans
             .iter()
             .map(|span| {
-                let mut result = vec![0.0; 26];
+                let mut result = vec![1.0; 26];
                 for letter in span.chars() {
-                    if letter as u32 > 'a' as u32 {
+                    let letter = letter.to_ascii_lowercase();
+                    if letter as u32 >= 'a' as u32 {
                         let ix = (letter as u32) - ('a' as u32);
                         if ix < 26 {
                             result[ix as usize] += 1.0;