remove reindexing subscription, and add status methods for vector store

KCaverly and maxbrunsfeld created

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

Cargo.lock                                    |   1 
crates/vector_store/Cargo.toml                |   1 
crates/vector_store/src/modal.rs              |   2 
crates/vector_store/src/vector_store.rs       | 379 +++++++-------------
crates/vector_store/src/vector_store_tests.rs |  78 +++
5 files changed, 208 insertions(+), 253 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -8493,6 +8493,7 @@ dependencies = [
  "lazy_static",
  "log",
  "matrixmultiply",
+ "parking_lot 0.11.2",
  "picker",
  "project",
  "rand 0.8.5",

crates/vector_store/Cargo.toml 🔗

@@ -33,6 +33,7 @@ async-trait.workspace = true
 bincode = "1.3.3"
 matrixmultiply = "0.3.7"
 tiktoken-rs = "0.5.0"
+parking_lot.workspace = true
 rand.workspace = true
 schemars.workspace = true
 

crates/vector_store/src/modal.rs 🔗

@@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate {
             if let Some(retrieved) = retrieved_cached.log_err() {
                 if !retrieved {
                     let task = vector_store.update(&mut cx, |store, cx| {
-                        store.search(project.clone(), query.to_string(), 10, cx)
+                        store.search_project(project.clone(), query.to_string(), 10, cx)
                     });
 
                     if let Some(results) = task.await.log_err() {

crates/vector_store/src/vector_store.rs 🔗

@@ -18,15 +18,19 @@ use gpui::{
 };
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
+use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document};
-use project::{Fs, PathChange, Project, ProjectEntryId, WorktreeId};
+use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{
-    collections::HashMap,
+    collections::{HashMap, HashSet},
     ops::Range,
     path::{Path, PathBuf},
-    sync::Arc,
-    time::{Duration, Instant, SystemTime},
+    sync::{
+        atomic::{self, AtomicUsize},
+        Arc, Weak,
+    },
+    time::{Instant, SystemTime},
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -99,7 +103,7 @@ pub fn init(
                         let project = workspace.read(cx).project().clone();
                         if project.read(cx).is_local() {
                             vector_store.update(cx, |store, cx| {
-                                store.add_project(project, cx).detach();
+                                store.index_project(project, cx).detach();
                             });
                         }
                     }
@@ -124,13 +128,20 @@ pub struct VectorStore {
     _embed_batch_task: Task<()>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
+    next_job_id: Arc<AtomicUsize>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
-    pending_files: HashMap<PathBuf, (PendingFile, SystemTime)>,
-    _subscription: gpui::Subscription,
+    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
+}
+
+type JobId = usize;
+
+struct JobHandle {
+    id: JobId,
+    set: Weak<Mutex<HashSet<JobId>>>,
 }
 
 impl ProjectState {
@@ -157,54 +168,15 @@ impl ProjectState {
                 }
             })
     }
-
-    fn update_pending_files(&mut self, pending_file: PendingFile, indexing_time: SystemTime) {
-        // If Pending File Already Exists, Replace it with the new one
-        // but keep the old indexing time
-        if let Some(old_file) = self
-            .pending_files
-            .remove(&pending_file.relative_path.clone())
-        {
-            self.pending_files.insert(
-                pending_file.relative_path.clone(),
-                (pending_file, old_file.1),
-            );
-        } else {
-            self.pending_files.insert(
-                pending_file.relative_path.clone(),
-                (pending_file, indexing_time),
-            );
-        };
-    }
-
-    fn get_outstanding_files(&mut self) -> Vec<PendingFile> {
-        let mut outstanding_files = vec![];
-        let mut remove_keys = vec![];
-        for key in self.pending_files.keys().into_iter() {
-            if let Some(pending_details) = self.pending_files.get(key) {
-                let (pending_file, index_time) = pending_details;
-                if index_time <= &SystemTime::now() {
-                    outstanding_files.push(pending_file.clone());
-                    remove_keys.push(key.clone());
-                }
-            }
-        }
-
-        for key in remove_keys.iter() {
-            self.pending_files.remove(key);
-        }
-
-        return outstanding_files;
-    }
 }
 
-#[derive(Clone, Debug)]
 pub struct PendingFile {
     worktree_db_id: i64,
     relative_path: PathBuf,
     absolute_path: PathBuf,
     language: Arc<Language>,
     modified_time: SystemTime,
+    job_handle: JobHandle,
 }
 
 #[derive(Debug, Clone)]
@@ -221,6 +193,7 @@ enum DbOperation {
         documents: Vec<Document>,
         path: PathBuf,
         mtime: SystemTime,
+        job_handle: JobHandle,
     },
     Delete {
         worktree_id: i64,
@@ -242,6 +215,7 @@ enum EmbeddingJob {
         path: PathBuf,
         mtime: SystemTime,
         documents: Vec<Document>,
+        job_handle: JobHandle,
     },
     Flush,
 }
@@ -274,9 +248,11 @@ impl VectorStore {
                             documents,
                             path,
                             mtime,
+                            job_handle,
                         } => {
                             db.insert_file(worktree_id, path, mtime, documents)
                                 .log_err();
+                            drop(job_handle)
                         }
                         DbOperation::Delete { worktree_id, path } => {
                             db.delete_file(worktree_id, path).log_err();
@@ -298,7 +274,7 @@ impl VectorStore {
 
             // embed_tx/rx: Embed Batch and Send to Database
             let (embed_batch_tx, embed_batch_rx) =
-                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime)>>();
+                channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
             let _embed_batch_task = cx.background().spawn({
                 let db_update_tx = db_update_tx.clone();
                 let embedding_provider = embedding_provider.clone();
@@ -306,7 +282,7 @@ impl VectorStore {
                     while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
                         // Construct Batch
                         let mut batch_documents = vec![];
-                        for (_, documents, _, _) in embeddings_queue.iter() {
+                        for (_, documents, _, _, _) in embeddings_queue.iter() {
                             batch_documents
                                 .extend(documents.iter().map(|document| document.content.as_str()));
                         }
@@ -333,7 +309,7 @@ impl VectorStore {
                                 j += 1;
                             }
 
-                            for (worktree_id, documents, path, mtime) in
+                            for (worktree_id, documents, path, mtime, job_handle) in
                                 embeddings_queue.into_iter()
                             {
                                 for document in documents.iter() {
@@ -350,6 +326,7 @@ impl VectorStore {
                                         documents,
                                         path,
                                         mtime,
+                                        job_handle,
                                     })
                                     .await
                                     .unwrap();
@@ -372,9 +349,16 @@ impl VectorStore {
                             worktree_id,
                             path,
                             mtime,
+                            job_handle,
                         } => {
                             queue_len += &documents.len();
-                            embeddings_queue.push((worktree_id, documents, path, mtime));
+                            embeddings_queue.push((
+                                worktree_id,
+                                documents,
+                                path,
+                                mtime,
+                                job_handle,
+                            ));
                             queue_len >= EMBEDDINGS_BATCH_SIZE
                         }
                         EmbeddingJob::Flush => true,
@@ -420,6 +404,7 @@ impl VectorStore {
                                         worktree_id: pending_file.worktree_db_id,
                                         path: pending_file.relative_path,
                                         mtime: pending_file.modified_time,
+                                        job_handle: pending_file.job_handle,
                                         documents,
                                     })
                                     .unwrap();
@@ -439,6 +424,7 @@ impl VectorStore {
                 embedding_provider,
                 language_registry,
                 db_update_tx,
+                next_job_id: Default::default(),
                 parsing_files_tx,
                 _db_update_task,
                 _embed_batch_task,
@@ -471,11 +457,11 @@ impl VectorStore {
         async move { rx.await? }
     }
 
-    fn add_project(
+    fn index_project(
         &mut self,
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
+    ) -> Task<Result<usize>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -494,21 +480,16 @@ impl VectorStore {
             })
             .collect::<Vec<_>>();
 
-        let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
-        let database_url = self.database_url.clone();
         let db_update_tx = self.db_update_tx.clone();
         let parsing_files_tx = self.parsing_files_tx.clone();
+        let next_job_id = self.next_job_id.clone();
 
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
             let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
 
-            if let Some(db_directory) = database_url.parent() {
-                fs.create_dir(db_directory).await.log_err();
-            }
-
             let worktrees = project.read_with(&cx, |project, cx| {
                 project
                     .worktrees(cx)
@@ -516,109 +497,115 @@ impl VectorStore {
                     .collect::<Vec<_>>()
             });
 
-            let mut worktree_file_times = HashMap::new();
+            let mut worktree_file_mtimes = HashMap::new();
             let mut db_ids_by_worktree_id = HashMap::new();
             for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
                 let db_id = db_id?;
                 db_ids_by_worktree_id.insert(worktree.id(), db_id);
-                worktree_file_times.insert(
+                worktree_file_mtimes.insert(
                     worktree.id(),
                     this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
                         .await?,
                 );
             }
 
-            cx.background()
-                .spawn({
-                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
-                    let db_update_tx = db_update_tx.clone();
-                    let language_registry = language_registry.clone();
-                    let parsing_files_tx = parsing_files_tx.clone();
-                    async move {
-                        let t0 = Instant::now();
-                        for worktree in worktrees.into_iter() {
-                            let mut file_mtimes =
-                                worktree_file_times.remove(&worktree.id()).unwrap();
-                            for file in worktree.files(false, 0) {
-                                let absolute_path = worktree.absolutize(&file.path);
-
-                                if let Ok(language) = language_registry
-                                    .language_for_file(&absolute_path, None)
-                                    .await
-                                {
-                                    if language
-                                        .grammar()
-                                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                                        .is_none()
-                                    {
-                                        continue;
-                                    }
-
-                                    let path_buf = file.path.to_path_buf();
-                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
-                                    let already_stored = stored_mtime
-                                        .map_or(false, |existing_mtime| {
-                                            existing_mtime == file.mtime
-                                        });
-
-                                    if !already_stored {
-                                        log::trace!("sending for parsing: {:?}", path_buf);
-                                        parsing_files_tx
-                                            .try_send(PendingFile {
-                                                worktree_db_id: db_ids_by_worktree_id
-                                                    [&worktree.id()],
-                                                relative_path: path_buf,
-                                                absolute_path,
-                                                language,
-                                                modified_time: file.mtime,
-                                            })
-                                            .unwrap();
-                                    }
-                                }
-                            }
-                            for file in file_mtimes.keys() {
-                                db_update_tx
-                                    .try_send(DbOperation::Delete {
-                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
-                                        path: file.to_owned(),
-                                    })
-                                    .unwrap();
-                            }
-                        }
-                        log::trace!(
-                            "parsing worktree completed in {:?}",
-                            t0.elapsed().as_millis()
-                        );
-                    }
-                })
-                .detach();
-
             // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
-            this.update(&mut cx, |this, cx| {
-                // The below is managing for updated on save
-                // Currently each time a file is saved, this code is run, and for all the files that were changed, if the current time is
-                // greater than the previous embedded time by the REINDEXING_DELAY variable, we will send the file off to be indexed.
-                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
-                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
-                        this.project_entries_changed(project, changes.clone(), cx, worktree_id);
-                    }
-                });
-
+            let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
+            this.update(&mut cx, |this, _| {
                 this.projects.insert(
                     project.downgrade(),
                     ProjectState {
-                        pending_files: HashMap::new(),
-                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
-                        _subscription,
+                        worktree_db_ids: db_ids_by_worktree_id
+                            .iter()
+                            .map(|(a, b)| (*a, *b))
+                            .collect(),
+                        outstanding_jobs: outstanding_jobs.clone(),
                     },
                 );
             });
 
-            anyhow::Ok(())
+            cx.background()
+                .spawn(async move {
+                    let mut count = 0;
+                    let t0 = Instant::now();
+                    for worktree in worktrees.into_iter() {
+                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
+                        for file in worktree.files(false, 0) {
+                            let absolute_path = worktree.absolutize(&file.path);
+
+                            if let Ok(language) = language_registry
+                                .language_for_file(&absolute_path, None)
+                                .await
+                            {
+                                if language
+                                    .grammar()
+                                    .and_then(|grammar| grammar.embedding_config.as_ref())
+                                    .is_none()
+                                {
+                                    continue;
+                                }
+
+                                let path_buf = file.path.to_path_buf();
+                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+                                let already_stored = stored_mtime
+                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
+
+                                if !already_stored {
+                                    log::trace!("sending for parsing: {:?}", path_buf);
+                                    count += 1;
+                                    let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
+                                    let job_handle = JobHandle {
+                                        id: job_id,
+                                        set: Arc::downgrade(&outstanding_jobs),
+                                    };
+                                    outstanding_jobs.lock().insert(job_id);
+                                    parsing_files_tx
+                                        .try_send(PendingFile {
+                                            worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
+                                            relative_path: path_buf,
+                                            absolute_path,
+                                            language,
+                                            job_handle,
+                                            modified_time: file.mtime,
+                                        })
+                                        .unwrap();
+                                }
+                            }
+                        }
+                        for file in file_mtimes.keys() {
+                            db_update_tx
+                                .try_send(DbOperation::Delete {
+                                    worktree_id: db_ids_by_worktree_id[&worktree.id()],
+                                    path: file.to_owned(),
+                                })
+                                .unwrap();
+                        }
+                    }
+                    log::trace!(
+                        "parsing worktree completed in {:?}",
+                        t0.elapsed().as_millis()
+                    );
+
+                    Ok(count)
+                })
+                .await
         })
     }
 
-    pub fn search(
+    pub fn remaining_files_to_index_for_project(
+        &self,
+        project: &ModelHandle<Project>,
+    ) -> Option<usize> {
+        Some(
+            self.projects
+                .get(&project.downgrade())?
+                .outstanding_jobs
+                .lock()
+                .len(),
+        )
+    }
+
+    pub fn search_project(
         &mut self,
         project: ModelHandle<Project>,
         phrase: String,
@@ -682,110 +669,16 @@ impl VectorStore {
             })
         })
     }
-
-    fn project_entries_changed(
-        &mut self,
-        project: ModelHandle<Project>,
-        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
-        cx: &mut ModelContext<'_, VectorStore>,
-        worktree_id: &WorktreeId,
-    ) -> Option<()> {
-        let reindexing_delay = settings::get::<VectorStoreSettings>(cx).reindexing_delay_seconds;
-
-        let worktree = project
-            .read(cx)
-            .worktree_for_id(worktree_id.clone(), cx)?
-            .read(cx)
-            .snapshot();
-
-        let worktree_db_id = self
-            .projects
-            .get(&project.downgrade())?
-            .db_id_for_worktree_id(worktree.id())?;
-        let file_mtimes = self.get_file_mtimes(worktree_db_id);
-
-        let language_registry = self.language_registry.clone();
-
-        cx.spawn(|this, mut cx| async move {
-            let file_mtimes = file_mtimes.await.log_err()?;
-
-            for change in changes.into_iter() {
-                let change_path = change.0.clone();
-                let absolute_path = worktree.absolutize(&change_path);
-
-                // Skip if git ignored or symlink
-                if let Some(entry) = worktree.entry_for_id(change.1) {
-                    if entry.is_ignored || entry.is_symlink || entry.is_external {
-                        continue;
-                    }
-                }
-
-                match change.2 {
-                    PathChange::Removed => this.update(&mut cx, |this, _| {
-                        this.db_update_tx
-                            .try_send(DbOperation::Delete {
-                                worktree_id: worktree_db_id,
-                                path: absolute_path,
-                            })
-                            .unwrap();
-                    }),
-                    _ => {
-                        if let Ok(language) = language_registry
-                            .language_for_file(&change_path.to_path_buf(), None)
-                            .await
-                        {
-                            if language
-                                .grammar()
-                                .and_then(|grammar| grammar.embedding_config.as_ref())
-                                .is_none()
-                            {
-                                continue;
-                            }
-
-                            let modified_time =
-                                change_path.metadata().log_err()?.modified().log_err()?;
-
-                            let existing_time = file_mtimes.get(&change_path.to_path_buf());
-                            let already_stored = existing_time
-                                .map_or(false, |existing_time| &modified_time != existing_time);
-
-                            if !already_stored {
-                                this.update(&mut cx, |this, _| {
-                                    let reindex_time = modified_time
-                                        + Duration::from_secs(reindexing_delay as u64);
-
-                                    let project_state =
-                                        this.projects.get_mut(&project.downgrade())?;
-                                    project_state.update_pending_files(
-                                        PendingFile {
-                                            relative_path: change_path.to_path_buf(),
-                                            absolute_path,
-                                            modified_time,
-                                            worktree_db_id,
-                                            language: language.clone(),
-                                        },
-                                        reindex_time,
-                                    );
-
-                                    for file in project_state.get_outstanding_files() {
-                                        this.parsing_files_tx.try_send(file).unwrap();
-                                    }
-                                    Some(())
-                                });
-                            }
-                        }
-                    }
-                }
-            }
-
-            Some(())
-        })
-        .detach();
-
-        Some(())
-    }
 }
 
 impl Entity for VectorStore {
     type Event = ();
 }
+
+impl Drop for JobHandle {
+    fn drop(&mut self) {
+        if let Some(set) = self.set.upgrade() {
+            set.lock().remove(&self.id);
+        }
+    }
+}

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -9,11 +9,17 @@ use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry};
-use project::{project_settings::ProjectSettings, FakeFs, Project};
+use project::{project_settings::ProjectSettings, FakeFs, Fs, Project};
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::{path::Path, sync::Arc};
+use std::{
+    path::Path,
+    sync::{
+        atomic::{self, AtomicUsize},
+        Arc,
+    },
+};
 use unindent::Unindent;
 
 #[ctor::ctor]
@@ -62,29 +68,37 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     let db_dir = tempdir::TempDir::new("vector-store").unwrap();
     let db_path = db_dir.path().join("db.sqlite");
 
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let store = VectorStore::new(
         fs.clone(),
         db_path,
-        Arc::new(FakeEmbeddingProvider),
+        embedding_provider.clone(),
         languages,
         cx.to_async(),
     )
     .await
     .unwrap();
 
-    let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
+    let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
     let worktree_id = project.read_with(cx, |project, cx| {
         project.worktrees(cx).next().unwrap().read(cx).id()
     });
-    store
-        .update(cx, |store, cx| store.add_project(project.clone(), cx))
+    let file_count = store
+        .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
+    assert_eq!(file_count, 2);
     cx.foreground().run_until_parked();
+    store.update(cx, |store, _cx| {
+        assert_eq!(
+            store.remaining_files_to_index_for_project(&project),
+            Some(0)
+        );
+    });
 
     let search_results = store
         .update(cx, |store, cx| {
-            store.search(project.clone(), "aaaa".to_string(), 5, cx)
+            store.search_project(project.clone(), "aaaa".to_string(), 5, cx)
         })
         .await
         .unwrap();
@@ -92,10 +106,45 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     assert_eq!(search_results[0].byte_range.start, 0);
     assert_eq!(search_results[0].name, "aaa");
     assert_eq!(search_results[0].worktree_id, worktree_id);
+
+    fs.save(
+        "/the-root/src/file2.rs".as_ref(),
+        &"
+            fn dddd() { println!(\"ddddd!\"); }
+            struct pqpqpqp {}
+        "
+        .unindent()
+        .into(),
+        Default::default(),
+    )
+    .await
+    .unwrap();
+
+    cx.foreground().run_until_parked();
+
+    let prev_embedding_count = embedding_provider.embedding_count();
+    let file_count = store
+        .update(cx, |store, cx| store.index_project(project.clone(), cx))
+        .await
+        .unwrap();
+    assert_eq!(file_count, 1);
+
+    cx.foreground().run_until_parked();
+    store.update(cx, |store, _cx| {
+        assert_eq!(
+            store.remaining_files_to_index_for_project(&project),
+            Some(0)
+        );
+    });
+
+    assert_eq!(
+        embedding_provider.embedding_count() - prev_embedding_count,
+        2
+    );
 }
 
 #[gpui::test]
-async fn test_code_context_retrieval(cx: &mut TestAppContext) {
+async fn test_code_context_retrieval() {
     let language = rust_lang();
     let mut retriever = CodeContextRetriever::new();
 
@@ -181,11 +230,22 @@ fn test_dot_product(mut rng: StdRng) {
     }
 }
 
-struct FakeEmbeddingProvider;
+#[derive(Default)]
+struct FakeEmbeddingProvider {
+    embedding_count: AtomicUsize,
+}
+
+impl FakeEmbeddingProvider {
+    fn embedding_count(&self) -> usize {
+        self.embedding_count.load(atomic::Ordering::SeqCst)
+    }
+}
 
 #[async_trait]
 impl EmbeddingProvider for FakeEmbeddingProvider {
     async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
+        self.embedding_count
+            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
         Ok(spans
             .iter()
             .map(|span| {