Reduced redundant database connections on each worktree change.

Authored by KCaverly and maxbrunsfeld

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

crates/vector_store/src/db.rs           |  78 +++++-
crates/vector_store/src/vector_store.rs | 282 +++++++++++---------------
2 files changed, 182 insertions(+), 178 deletions(-)

Detailed changes

crates/vector_store/src/db.rs

@@ -1,4 +1,5 @@
 use std::{
+    cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
     rc::Rc,
@@ -14,16 +15,6 @@ use rusqlite::{
     types::{FromSql, FromSqlResult, ValueRef},
 };
 
-// Note this is not an appropriate document
-#[derive(Debug)]
-pub struct DocumentRecord {
-    pub id: usize,
-    pub file_id: usize,
-    pub offset: usize,
-    pub name: String,
-    pub embedding: Embedding,
-}
-
 #[derive(Debug)]
 pub struct FileRecord {
     pub id: usize,
@@ -32,7 +23,7 @@ pub struct FileRecord {
 }
 
 #[derive(Debug)]
-pub struct Embedding(pub Vec<f32>);
+struct Embedding(pub Vec<f32>);
 
 impl FromSql for Embedding {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
@@ -205,10 +196,35 @@ impl VectorDatabase {
         Ok(result)
     }
 
-    pub fn for_each_document(
+    pub fn top_k_search(
+        &self,
+        worktree_ids: &[i64],
+        query_embedding: &Vec<f32>,
+        limit: usize,
+    ) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+        let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
+        self.for_each_document(&worktree_ids, |id, embedding| {
+            eprintln!("document {id} {embedding:?}");
+
+            let similarity = dot(&embedding, &query_embedding);
+            let ix = match results
+                .binary_search_by(|(_, s)| similarity.partial_cmp(&s).unwrap_or(Ordering::Equal))
+            {
+                Ok(ix) => ix,
+                Err(ix) => ix,
+            };
+            results.insert(ix, (id, similarity));
+            results.truncate(limit);
+        })?;
+
+        let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
+        self.get_documents_by_ids(&ids)
+    }
+
+    fn for_each_document(
         &self,
         worktree_ids: &[i64],
-        mut f: impl FnMut(i64, Embedding),
+        mut f: impl FnMut(i64, Vec<f32>),
     ) -> Result<()> {
         let mut query_statement = self.db.prepare(
             "
@@ -221,16 +237,20 @@ impl VectorDatabase {
                 files.worktree_id IN rarray(?)
             ",
         )?;
+
         query_statement
             .query_map(params![ids_to_sql(worktree_ids)], |row| {
-                Ok((row.get(0)?, row.get(1)?))
+                Ok((row.get(0)?, row.get::<_, Embedding>(1)?))
             })?
             .filter_map(|row| row.ok())
-            .for_each(|row| f(row.0, row.1));
+            .for_each(|(id, embedding)| {
+                dbg!("id");
+                f(id, embedding.0)
+            });
         Ok(())
     }
 
-    pub fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
+    fn get_documents_by_ids(&self, ids: &[i64]) -> Result<Vec<(i64, PathBuf, usize, String)>> {
         let mut statement = self.db.prepare(
             "
                 SELECT
@@ -279,3 +299,29 @@ fn ids_to_sql(ids: &[i64]) -> Rc<Vec<rusqlite::types::Value>> {
             .collect::<Vec<_>>(),
     )
 }
+
+pub(crate) fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
+    let len = vec_a.len();
+    assert_eq!(len, vec_b.len());
+
+    let mut result = 0.0;
+    unsafe {
+        matrixmultiply::sgemm(
+            1,
+            len,
+            1,
+            1.0,
+            vec_a.as_ptr(),
+            len as isize,
+            1,
+            vec_b.as_ptr(),
+            1,
+            len as isize,
+            0.0,
+            &mut result as *mut f32,
+            1,
+            1,
+        );
+    }
+    result
+}
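
The `top_k_search` added above avoids sorting every similarity score: it keeps a buffer of at most `limit` hits ordered by descending similarity, locates each candidate's slot with `binary_search_by`, inserts, and truncates. Note that the comparator calls `similarity.partial_cmp(&s)` rather than `s.partial_cmp(&similarity)`; the reversed comparison is what keeps the buffer descending. Below is a minimal, dependency-free sketch of that pattern. `top_k` is a hypothetical standalone name, and the naive `dot` stands in for the diff's `matrixmultiply::sgemm` call, which computes the same inner product as a 1×len by len×1 matrix multiply.

    use std::cmp::Ordering;

    // Naive stand-in for the sgemm-backed `dot` above: a 1 x n by
    // n x 1 matrix product is exactly an inner product.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        assert_eq!(a.len(), b.len());
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    // Keep only the `limit` highest-similarity ids, sorted descending.
    fn top_k(candidates: &[(i64, Vec<f32>)], query: &[f32], limit: usize) -> Vec<(i64, f32)> {
        let mut results: Vec<(i64, f32)> = Vec::with_capacity(limit + 1);
        for (id, embedding) in candidates {
            let similarity = dot(embedding, query);
            // Comparing `similarity` against each element (instead of the
            // element against `similarity`) reverses the ordering, so the
            // insertion point found here keeps `results` sorted descending.
            let ix = results
                .binary_search_by(|(_, s)| similarity.partial_cmp(s).unwrap_or(Ordering::Equal))
                .unwrap_or_else(|ix| ix);
            results.insert(ix, (*id, similarity));
            results.truncate(limit);
        }
        results
    }

Each insertion costs O(log limit) to find the slot plus O(limit) to shift, so a scan over n documents is O(n·limit) in the worst case but never materializes more than limit + 1 entries at once.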

crates/vector_store/src/vector_store.rs

@@ -20,7 +20,6 @@ use parsing::{CodeContextRetriever, ParsedFile};
 use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
 use smol::channel;
 use std::{
-    cmp::Ordering,
     collections::HashMap,
     path::{Path, PathBuf},
     sync::Arc,
@@ -112,10 +111,10 @@ pub struct VectorStore {
     database_url: Arc<PathBuf>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    db_update_tx: channel::Sender<DbWrite>,
+    db_update_tx: channel::Sender<DbOperation>,
     parsing_files_tx: channel::Sender<PendingFile>,
     _db_update_task: Task<()>,
-    _embed_batch_task: Vec<Task<()>>,
+    _embed_batch_task: Task<()>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
@@ -128,6 +127,30 @@ struct ProjectState {
 }
 
 impl ProjectState {
+    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
+        self.worktree_db_ids
+            .iter()
+            .find_map(|(worktree_id, db_id)| {
+                if *worktree_id == id {
+                    Some(*db_id)
+                } else {
+                    None
+                }
+            })
+    }
+
+    fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
+        self.worktree_db_ids
+            .iter()
+            .find_map(|(worktree_id, db_id)| {
+                if *db_id == id {
+                    Some(*worktree_id)
+                } else {
+                    None
+                }
+            })
+    }
+
     fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
         // If Pending File Already Exists, Replace it with the new one
         // but keep the old indexing time
@@ -185,7 +208,7 @@ pub struct SearchResult {
     pub file_path: PathBuf,
 }
 
-enum DbWrite {
+enum DbOperation {
     InsertFile {
         worktree_id: i64,
         indexed_file: ParsedFile,
@@ -198,6 +221,10 @@ enum DbWrite {
         path: PathBuf,
         sender: oneshot::Sender<Result<i64>>,
     },
+    FileMTimes {
+        worktree_id: i64,
+        sender: oneshot::Sender<Result<HashMap<PathBuf, SystemTime>>>,
+    },
 }
 
 enum EmbeddingJob {
@@ -243,20 +270,27 @@ impl VectorStore {
             let _db_update_task = cx.background().spawn(async move {
                 while let Ok(job) = db_update_rx.recv().await {
                     match job {
-                        DbWrite::InsertFile {
+                        DbOperation::InsertFile {
                             worktree_id,
                             indexed_file,
                         } => {
                             log::info!("Inserting Data for {:?}", &indexed_file.path);
                             db.insert_file(worktree_id, indexed_file).log_err();
                         }
-                        DbWrite::Delete { worktree_id, path } => {
+                        DbOperation::Delete { worktree_id, path } => {
                             db.delete_file(worktree_id, path).log_err();
                         }
-                        DbWrite::FindOrCreateWorktree { path, sender } => {
+                        DbOperation::FindOrCreateWorktree { path, sender } => {
                             let id = db.find_or_create_worktree(&path);
                             sender.send(id).ok();
                         }
+                        DbOperation::FileMTimes {
+                            worktree_id: worktree_db_id,
+                            sender,
+                        } => {
+                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                            sender.send(file_mtimes).ok();
+                        }
                     }
                 }
             });
@@ -264,24 +298,18 @@ impl VectorStore {
             // embed_tx/rx: Embed Batch and Send to Database
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, ParsedFile, Vec<String>)>>();
-            let mut _embed_batch_task = Vec::new();
-            for _ in 0..1 {
-                //cx.background().num_cpus() {
+            let _embed_batch_task = cx.background().spawn({
                 let db_update_tx = db_update_tx.clone();
-                let embed_batch_rx = embed_batch_rx.clone();
                 let embedding_provider = embedding_provider.clone();
-                _embed_batch_task.push(cx.background().spawn(async move {
-                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                async move {
+                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                         // Construct Batch
-                        let mut embeddings_queue = embeddings_queue.clone();
                         let mut document_spans = vec![];
-                        for (_, _, document_span) in embeddings_queue.clone().into_iter() {
-                            document_spans.extend(document_span);
+                        for (_, _, document_span) in embeddings_queue.iter() {
+                            document_spans.extend(document_span.iter().map(|s| s.as_str()));
                         }
 
-                        if let Ok(embeddings) = embedding_provider
-                            .embed_batch(document_spans.iter().map(|x| &**x).collect())
-                            .await
+                        if let Ok(embeddings) = embedding_provider.embed_batch(document_spans).await
                         {
                             let mut i = 0;
                             let mut j = 0;
@@ -306,7 +334,7 @@ impl VectorStore {
                                 }
 
                                 db_update_tx
-                                    .send(DbWrite::InsertFile {
+                                    .send(DbOperation::InsertFile {
                                         worktree_id,
                                         indexed_file,
                                     })
@@ -315,8 +343,9 @@ impl VectorStore {
                             }
                         }
                     }
-                }))
-            }
+                }
+            });
+
             // batch_tx/rx: Batch Files to Send for Embeddings
             let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
             let _batch_files_task = cx.background().spawn(async move {
@@ -398,7 +427,21 @@ impl VectorStore {
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
         let (tx, rx) = oneshot::channel();
         self.db_update_tx
-            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
+            .try_send(DbOperation::FindOrCreateWorktree { path, sender: tx })
+            .unwrap();
+        async move { rx.await? }
+    }
+
+    fn get_file_mtimes(
+        &self,
+        worktree_id: i64,
+    ) -> impl Future<Output = Result<HashMap<PathBuf, SystemTime>>> {
+        let (tx, rx) = oneshot::channel();
+        self.db_update_tx
+            .try_send(DbOperation::FileMTimes {
+                worktree_id,
+                sender: tx,
+            })
             .unwrap();
         async move { rx.await? }
     }
@@ -450,26 +493,17 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            // Here we query the worktree ids, and yet we dont have them elsewhere
-            // We likely want to clean up these datastructures
-            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
-                .background()
-                .spawn({
-                    let worktrees = worktrees.clone();
-                    async move {
-                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
-                        let mut db_ids_by_worktree_id = HashMap::new();
-                        let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
-                            HashMap::new();
-                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
-                            let db_id = db_id?;
-                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
-                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
-                        }
-                        anyhow::Ok((file_times, db_ids_by_worktree_id))
-                    }
-                })
-                .await?;
+            let mut worktree_file_times = HashMap::new();
+            let mut db_ids_by_worktree_id = HashMap::new();
+            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+                let db_id = db_id?;
+                db_ids_by_worktree_id.insert(worktree.id(), db_id);
+                worktree_file_times.insert(
+                    worktree.id(),
+                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                        .await?,
+                );
+            }
 
             cx.background()
                 .spawn({
@@ -520,7 +554,7 @@ impl VectorStore {
                             }
                             for file in file_mtimes.keys() {
                                 db_update_tx
-                                    .try_send(DbWrite::Delete {
+                                    .try_send(DbOperation::Delete {
                                         worktree_id: db_ids_by_worktree_id[&worktree.id()],
                                         path: file.to_owned(),
                                     })
@@ -542,7 +576,7 @@ impl VectorStore {
                 // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
                 let _subscription = cx.subscribe(&project, |this, project, event, cx| {
                     if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
-                        this.project_entries_changed(project, changes, cx, worktree_id);
+                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
                     }
                 });
 
@@ -578,16 +612,7 @@ impl VectorStore {
             .worktrees(cx)
             .filter_map(|worktree| {
                 let worktree_id = worktree.read(cx).id();
-                project_state
-                    .worktree_db_ids
-                    .iter()
-                    .find_map(|(id, db_id)| {
-                        if *id == worktree_id {
-                            Some(*db_id)
-                        } else {
-                            None
-                        }
-                    })
+                project_state.db_id_for_worktree_id(worktree_id)
             })
             .collect::<Vec<_>>();
 
@@ -606,24 +631,12 @@ impl VectorStore {
                         .next()
                         .unwrap();
 
-                    let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-                    database.for_each_document(&worktree_db_ids, |id, embedding| {
-                        let similarity = dot(&embedding.0, &phrase_embedding);
-                        let ix = match results.binary_search_by(|(_, s)| {
-                            similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
-                        }) {
-                            Ok(ix) => ix,
-                            Err(ix) => ix,
-                        };
-                        results.insert(ix, (id, similarity));
-                        results.truncate(limit);
-                    })?;
-
-                    let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<_>>();
-                    database.get_documents_by_ids(&ids)
+                    database.top_k_search(&worktree_db_ids, &phrase_embedding, limit)
                 })
                 .await?;
 
+            dbg!(&documents);
+
             this.read_with(&cx, |this, _| {
                 let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
                     state
@@ -634,17 +647,7 @@ impl VectorStore {
                 Ok(documents
                     .into_iter()
                     .filter_map(|(worktree_db_id, file_path, offset, name)| {
-                        let worktree_id =
-                            project_state
-                                .worktree_db_ids
-                                .iter()
-                                .find_map(|(id, db_id)| {
-                                    if *db_id == worktree_db_id {
-                                        Some(*id)
-                                    } else {
-                                        None
-                                    }
-                                })?;
+                        let worktree_id = project_state.worktree_id_for_db_id(worktree_db_id)?;
                         Some(SearchResult {
                             worktree_id,
                             name,
@@ -660,51 +663,36 @@ impl VectorStore {
     fn project_entries_changed(
         &mut self,
         project: ModelHandle<Project>,
-        changes: &[(Arc<Path>, ProjectEntryId, PathChange)],
+        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
         cx: &mut ModelContext<'_, VectorStore>,
         worktree_id: &WorktreeId,
     ) -> Option<()> {
-        let project_state = self.projects.get_mut(&project.downgrade())?;
-        let worktree_db_ids = project_state.worktree_db_ids.clone();
-        let worktree = project.read(cx).worktree_for_id(worktree_id.clone(), cx)?;
-
-        // Get Database
-        let (file_mtimes, worktree_db_id) = {
-            if let Ok(db) = VectorDatabase::new(self.database_url.to_string_lossy().into()) {
-                let worktree_db_id = {
-                    let mut found_db_id = None;
-                    for (w_id, db_id) in worktree_db_ids.into_iter() {
-                        if &w_id == &worktree.read(cx).id() {
-                            found_db_id = Some(db_id)
-                        }
-                    }
-                    found_db_id
-                }?;
-
-                let file_mtimes = db.get_file_mtimes(worktree_db_id).log_err()?;
+        let worktree = project
+            .read(cx)
+            .worktree_for_id(worktree_id.clone(), cx)?
+            .read(cx)
+            .snapshot();
 
-                Some((file_mtimes, worktree_db_id))
-            } else {
-                return None;
-            }
-        }?;
+        let worktree_db_id = self
+            .projects
+            .get(&project.downgrade())?
+            .db_id_for_worktree_id(worktree.id())?;
+        let file_mtimes = self.get_file_mtimes(worktree_db_id);
 
-        // Iterate Through Changes
         let language_registry = self.language_registry.clone();
-        let parsing_files_tx = self.parsing_files_tx.clone();
 
-        smol::block_on(async move {
+        cx.spawn(|this, mut cx| async move {
+            let file_mtimes = file_mtimes.await.log_err()?;
+
             for change in changes.into_iter() {
                 let change_path = change.0.clone();
-                let absolute_path = worktree.read(cx).absolutize(&change_path);
+                let absolute_path = worktree.absolutize(&change_path);
                 // Skip if git ignored or symlink
-                if let Some(entry) = worktree.read(cx).entry_for_id(change.1) {
-                    if entry.is_ignored || entry.is_symlink {
+                if let Some(entry) = worktree.entry_for_id(change.1) {
+                    if entry.is_ignored || entry.is_symlink || entry.is_external {
                         continue;
-                    } else {
-                        log::info!("Testing for Reindexing: {:?}", &change_path);
                     }
-                };
+                }
 
                 if let Ok(language) = language_registry
                     .language_for_file(&change_path.to_path_buf(), None)
@@ -718,27 +706,18 @@ impl VectorStore {
                         continue;
                     }
 
-                    if let Some(modified_time) = {
-                        let metadata = change_path.metadata();
-                        if metadata.is_err() {
-                            None
-                        } else {
-                            let mtime = metadata.unwrap().modified();
-                            if mtime.is_err() {
-                                None
-                            } else {
-                                Some(mtime.unwrap())
-                            }
-                        }
-                    } {
-                        let existing_time = file_mtimes.get(&change_path.to_path_buf());
-                        let already_stored = existing_time
-                            .map_or(false, |existing_time| &modified_time != existing_time);
+                    let modified_time = change_path.metadata().log_err()?.modified().log_err()?;
 
-                        let reindex_time =
-                            modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+                    let existing_time = file_mtimes.get(&change_path.to_path_buf());
+                    let already_stored = existing_time
+                        .map_or(false, |existing_time| &modified_time != existing_time);
 
-                        if !already_stored {
+                    if !already_stored {
+                        this.update(&mut cx, |this, _| {
+                            let reindex_time =
+                                modified_time + Duration::from_secs(REINDEXING_DELAY_SECONDS);
+
+                            let project_state = this.projects.get_mut(&project.downgrade())?;
                             project_state.update_pending_files(
                                 PendingFile {
                                     relative_path: change_path.to_path_buf(),
@@ -751,13 +730,18 @@ impl VectorStore {
                             );
 
                             for file in project_state.get_outstanding_files() {
-                                parsing_files_tx.try_send(file).unwrap();
+                                this.parsing_files_tx.try_send(file).unwrap();
                             }
-                        }
+                            Some(())
+                        });
                     }
                 }
             }
-        });
+
+            Some(())
+        })
+        .detach();
+
         Some(())
     }
 }
@@ -765,29 +749,3 @@ impl VectorStore {
 impl Entity for VectorStore {
     type Event = ();
 }
-
-fn dot(vec_a: &[f32], vec_b: &[f32]) -> f32 {
-    let len = vec_a.len();
-    assert_eq!(len, vec_b.len());
-
-    let mut result = 0.0;
-    unsafe {
-        matrixmultiply::sgemm(
-            1,
-            len,
-            1,
-            1.0,
-            vec_a.as_ptr(),
-            len as isize,
-            1,
-            vec_b.as_ptr(),
-            1,
-            len as isize,
-            0.0,
-            &mut result as *mut f32,
-            1,
-            1,
-        );
-    }
-    result
-}
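
The vector_store.rs changes are what the commit title describes. Previously, `project_entries_changed` opened a fresh `VectorDatabase` connection inside a blocking `smol::block_on` on every worktree event. Now it sends a `DbOperation::FileMTimes` request down the existing `db_update_tx` channel and awaits the reply on a oneshot sender, so the single connection owned by `_db_update_task` serves reads as well as writes; the switch to `changes: Arc<[...]>` lets the event payload move into the detached `cx.spawn` future that replaces the blocking call. A minimal sketch of that request/response-over-a-channel shape, assuming std threads and `std::sync::mpsc` in place of the crate's smol/futures machinery, with a stub where rusqlite would be:

    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::sync::mpsc;
    use std::thread;
    use std::time::SystemTime;

    // Mirrors the `DbOperation::FileMTimes` arm: the caller hands the
    // DB task a sender, and the task replies on it.
    enum DbOperation {
        FileMTimes {
            worktree_id: i64,
            sender: mpsc::Sender<HashMap<PathBuf, SystemTime>>,
        },
    }

    fn main() {
        let (db_tx, db_rx) = mpsc::channel::<DbOperation>();

        // One task owns the one database connection; every operation,
        // reads included, is serialized through this loop.
        let db_task = thread::spawn(move || {
            while let Ok(op) = db_rx.recv() {
                match op {
                    DbOperation::FileMTimes { worktree_id, sender } => {
                        // Stand-in for `db.get_file_mtimes(worktree_id)`.
                        let _ = worktree_id;
                        sender.send(HashMap::new()).ok();
                    }
                }
            }
        });

        // Caller side: send a request and wait for the reply, without
        // ever opening a second connection.
        let (tx, rx) = mpsc::channel();
        db_tx
            .send(DbOperation::FileMTimes { worktree_id: 1, sender: tx })
            .unwrap();
        let mtimes = rx.recv().unwrap();
        println!("{} file mtimes", mtimes.len());

        drop(db_tx); // closing the channel lets the DB task exit
        db_task.join().unwrap();
    }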