Moved the semantic search model to the dev and preview release channels only.

Created by KCaverly and maxbrunsfeld

Moved the database update tasks into a single long-lived, persistent background task.

Co-authored-by: maxbrunsfeld <max@zed.dev>

Change summary

crates/project/src/project.rs                 |   5 
crates/vector_store/src/modal.rs              |   2 
crates/vector_store/src/vector_store.rs       | 332 ++++++++++++++------
crates/vector_store/src/vector_store_tests.rs |  25 
4 files changed, 241 insertions(+), 123 deletions(-)

Detailed changes

crates/project/src/project.rs 🔗

@@ -260,6 +260,7 @@ pub enum Event {
     ActiveEntryChanged(Option<ProjectEntryId>),
     WorktreeAdded,
     WorktreeRemoved(WorktreeId),
+    WorktreeUpdatedEntries(WorktreeId, UpdatedEntriesSet),
     DiskBasedDiagnosticsStarted {
         language_server_id: LanguageServerId,
     },
@@ -5371,6 +5372,10 @@ impl Project {
                     this.update_local_worktree_buffers(&worktree, changes, cx);
                     this.update_local_worktree_language_servers(&worktree, changes, cx);
                     this.update_local_worktree_settings(&worktree, changes, cx);
+                    cx.emit(Event::WorktreeUpdatedEntries(
+                        worktree.read(cx).id(),
+                        changes.clone(),
+                    ));
                 }
                 worktree::Event::UpdatedGitRepositories(updated_repos) => {
                     this.update_local_worktree_buffers_git_repos(worktree, updated_repos, cx)

crates/vector_store/src/modal.rs 🔗

@@ -124,7 +124,7 @@ impl PickerDelegate for SemanticSearchDelegate {
             if let Some(retrieved) = retrieved_cached.log_err() {
                 if !retrieved {
                     let task = vector_store.update(&mut cx, |store, cx| {
-                        store.search(&project, query.to_string(), 10, cx)
+                        store.search(project.clone(), query.to_string(), 10, cx)
                     });
 
                     if let Some(results) = task.await.log_err() {

crates/vector_store/src/vector_store.rs 🔗

@@ -8,7 +8,11 @@ mod vector_store_tests;
 use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
-use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task, ViewContext};
+use futures::{channel::oneshot, Future};
+use gpui::{
+    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
+    WeakModelHandle,
+};
 use language::{Language, LanguageRegistry};
 use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use project::{Fs, Project, WorktreeId};
@@ -22,7 +26,10 @@ use std::{
 };
 use tree_sitter::{Parser, QueryCursor};
 use util::{
-    channel::RELEASE_CHANNEL_NAME, http::HttpClient, paths::EMBEDDINGS_DIR, ResultExt, TryFutureExt,
+    channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
+    http::HttpClient,
+    paths::EMBEDDINGS_DIR,
+    ResultExt,
 };
 use workspace::{Workspace, WorkspaceCreated};
 
@@ -39,12 +46,16 @@ pub fn init(
     language_registry: Arc<LanguageRegistry>,
     cx: &mut AppContext,
 ) {
+    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
+        return;
+    }
+
     let db_file_path = EMBEDDINGS_DIR
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    let vector_store = cx.add_model(|_| {
-        VectorStore::new(
+    cx.spawn(move |mut cx| async move {
+        let vector_store = VectorStore::new(
             fs,
             db_file_path,
             // Arc::new(embedding::DummyEmbeddings {}),
@@ -52,42 +63,49 @@ pub fn init(
                 client: http_client,
             }),
             language_registry,
+            cx.clone(),
         )
-    });
-
-    cx.subscribe_global::<WorkspaceCreated, _>({
-        let vector_store = vector_store.clone();
-        move |event, cx| {
-            let workspace = &event.0;
-            if let Some(workspace) = workspace.upgrade(cx) {
-                let project = workspace.read(cx).project().clone();
-                if project.read(cx).is_local() {
-                    vector_store.update(cx, |store, cx| {
-                        store.add_project(project, cx).detach();
-                    });
+        .await?;
+
+        cx.update(|cx| {
+            cx.subscribe_global::<WorkspaceCreated, _>({
+                let vector_store = vector_store.clone();
+                move |event, cx| {
+                    let workspace = &event.0;
+                    if let Some(workspace) = workspace.upgrade(cx) {
+                        let project = workspace.read(cx).project().clone();
+                        if project.read(cx).is_local() {
+                            vector_store.update(cx, |store, cx| {
+                                store.add_project(project, cx).detach();
+                            });
+                        }
+                    }
                 }
-            }
-        }
-    })
-    .detach();
-
-    cx.add_action({
-        move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
-            let vector_store = vector_store.clone();
-            workspace.toggle_modal(cx, |workspace, cx| {
-                let project = workspace.project().clone();
-                let workspace = cx.weak_handle();
-                cx.add_view(|cx| {
-                    SemanticSearch::new(
-                        SemanticSearchDelegate::new(workspace, project, vector_store),
-                        cx,
-                    )
-                })
             })
-        }
-    });
+            .detach();
+
+            cx.add_action({
+                move |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
+                    let vector_store = vector_store.clone();
+                    workspace.toggle_modal(cx, |workspace, cx| {
+                        let project = workspace.project().clone();
+                        let workspace = cx.weak_handle();
+                        cx.add_view(|cx| {
+                            SemanticSearch::new(
+                                SemanticSearchDelegate::new(workspace, project, vector_store),
+                                cx,
+                            )
+                        })
+                    })
+                }
+            });
+
+            SemanticSearch::init(cx);
+        });
 
-    SemanticSearch::init(cx);
+        anyhow::Ok(())
+    })
+    .detach();
 }
 
 #[derive(Debug)]
@@ -102,7 +120,14 @@ pub struct VectorStore {
     database_url: Arc<PathBuf>,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
+    db_update_tx: channel::Sender<DbWrite>,
+    _db_update_task: Task<()>,
+    projects: HashMap<WeakModelHandle<Project>, ProjectState>,
+}
+
+struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
+    _subscription: gpui::Subscription,
 }
 
 #[derive(Debug, Clone)]
@@ -113,20 +138,81 @@ pub struct SearchResult {
     pub file_path: PathBuf,
 }
 
+enum DbWrite {
+    InsertFile {
+        worktree_id: i64,
+        indexed_file: IndexedFile,
+    },
+    Delete {
+        worktree_id: i64,
+        path: PathBuf,
+    },
+    FindOrCreateWorktree {
+        path: PathBuf,
+        sender: oneshot::Sender<Result<i64>>,
+    },
+}
+
 impl VectorStore {
-    fn new(
+    async fn new(
         fs: Arc<dyn Fs>,
         database_url: PathBuf,
         embedding_provider: Arc<dyn EmbeddingProvider>,
         language_registry: Arc<LanguageRegistry>,
-    ) -> Self {
-        Self {
-            fs,
-            database_url: Arc::new(database_url),
-            embedding_provider,
-            language_registry,
-            worktree_db_ids: Vec::new(),
-        }
+        mut cx: AsyncAppContext,
+    ) -> Result<ModelHandle<Self>> {
+        let database_url = Arc::new(database_url);
+
+        let db = cx
+            .background()
+            .spawn({
+                let fs = fs.clone();
+                let database_url = database_url.clone();
+                async move {
+                    if let Some(db_directory) = database_url.parent() {
+                        fs.create_dir(db_directory).await.log_err();
+                    }
+
+                    let db = VectorDatabase::new(database_url.to_string_lossy().to_string())?;
+                    anyhow::Ok(db)
+                }
+            })
+            .await?;
+
+        Ok(cx.add_model(|cx| {
+            let (db_update_tx, db_update_rx) = channel::unbounded();
+            let _db_update_task = cx.background().spawn(async move {
+                while let Ok(job) = db_update_rx.recv().await {
+                    match job {
+                        DbWrite::InsertFile {
+                            worktree_id,
+                            indexed_file,
+                        } => {
+                            log::info!("Inserting File: {:?}", &indexed_file.path);
+                            db.insert_file(worktree_id, indexed_file).log_err();
+                        }
+                        DbWrite::Delete { worktree_id, path } => {
+                            log::info!("Deleting File: {:?}", &path);
+                            db.delete_file(worktree_id, path).log_err();
+                        }
+                        DbWrite::FindOrCreateWorktree { path, sender } => {
+                            let id = db.find_or_create_worktree(&path);
+                            sender.send(id).ok();
+                        }
+                    }
+                }
+            });
+
+            Self {
+                fs,
+                database_url,
+                db_update_tx,
+                embedding_provider,
+                language_registry,
+                projects: HashMap::new(),
+                _db_update_task,
+            }
+        }))
     }
 
     async fn index_file(
@@ -196,6 +282,14 @@ impl VectorStore {
         });
     }
 
+    fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
+        let (tx, rx) = oneshot::channel();
+        self.db_update_tx
+            .try_send(DbWrite::FindOrCreateWorktree { path, sender: tx })
+            .unwrap();
+        async move { rx.await? }
+    }
+
     fn add_project(
         &mut self,
         project: ModelHandle<Project>,
@@ -211,19 +305,28 @@ impl VectorStore {
                 }
             })
             .collect::<Vec<_>>();
+        let worktree_db_ids = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| {
+                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+            })
+            .collect::<Vec<_>>();
 
         let fs = self.fs.clone();
         let language_registry = self.language_registry.clone();
         let embedding_provider = self.embedding_provider.clone();
         let database_url = self.database_url.clone();
+        let db_update_tx = self.db_update_tx.clone();
 
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
+            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
+
             if let Some(db_directory) = database_url.parent() {
                 fs.create_dir(db_directory).await.log_err();
             }
-            let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
 
             let worktrees = project.read_with(&cx, |project, cx| {
                 project
@@ -234,32 +337,31 @@ impl VectorStore {
 
             // Here we query the worktree ids, and yet we dont have them elsewhere
             // We likely want to clean up these datastructures
-            let (db, mut worktree_file_times, worktree_db_ids) = cx
+            let (mut worktree_file_times, db_ids_by_worktree_id) = cx
                 .background()
                 .spawn({
                     let worktrees = worktrees.clone();
                     async move {
-                        let mut worktree_db_ids: HashMap<WorktreeId, i64> = HashMap::new();
+                        let db = VectorDatabase::new(database_url.to_string_lossy().into())?;
+                        let mut db_ids_by_worktree_id = HashMap::new();
                         let mut file_times: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>> =
                             HashMap::new();
-                        for worktree in worktrees {
-                            let worktree_db_id =
-                                db.find_or_create_worktree(worktree.abs_path().as_ref())?;
-                            worktree_db_ids.insert(worktree.id(), worktree_db_id);
-                            file_times.insert(worktree.id(), db.get_file_mtimes(worktree_db_id)?);
+                        for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+                            let db_id = db_id?;
+                            db_ids_by_worktree_id.insert(worktree.id(), db_id);
+                            file_times.insert(worktree.id(), db.get_file_mtimes(db_id)?);
                         }
-                        anyhow::Ok((db, file_times, worktree_db_ids))
+                        anyhow::Ok((file_times, db_ids_by_worktree_id))
                     }
                 })
                 .await?;
 
             let (paths_tx, paths_rx) =
                 channel::unbounded::<(i64, PathBuf, Arc<Language>, SystemTime)>();
-            let (delete_paths_tx, delete_paths_rx) = channel::unbounded::<(i64, PathBuf)>();
-            let (indexed_files_tx, indexed_files_rx) = channel::unbounded::<(i64, IndexedFile)>();
             cx.background()
                 .spawn({
-                    let worktree_db_ids = worktree_db_ids.clone();
+                    let db_ids_by_worktree_id = db_ids_by_worktree_id.clone();
+                    let db_update_tx = db_update_tx.clone();
                     async move {
                         for worktree in worktrees.into_iter() {
                             let mut file_mtimes =
@@ -289,7 +391,7 @@ impl VectorStore {
                                     if !already_stored {
                                         paths_tx
                                             .try_send((
-                                                worktree_db_ids[&worktree.id()],
+                                                db_ids_by_worktree_id[&worktree.id()],
                                                 path_buf,
                                                 language,
                                                 file.mtime,
@@ -299,8 +401,11 @@ impl VectorStore {
                                 }
                             }
                             for file in file_mtimes.keys() {
-                                delete_paths_tx
-                                    .try_send((worktree_db_ids[&worktree.id()], file.to_owned()))
+                                db_update_tx
+                                    .try_send(DbWrite::Delete {
+                                        worktree_id: db_ids_by_worktree_id[&worktree.id()],
+                                        path: file.to_owned(),
+                                    })
                                     .unwrap();
                             }
                         }
@@ -308,25 +413,6 @@ impl VectorStore {
                 })
                 .detach();
 
-            let db_update_task = cx.background().spawn(
-                async move {
-                    // Inserting all new files
-                    while let Ok((worktree_id, indexed_file)) = indexed_files_rx.recv().await {
-                        log::info!("Inserting File: {:?}", &indexed_file.path);
-                        db.insert_file(worktree_id, indexed_file).log_err();
-                    }
-
-                    // Deleting all old files
-                    while let Ok((worktree_id, delete_path)) = delete_paths_rx.recv().await {
-                        log::info!("Deleting File: {:?}", &delete_path);
-                        db.delete_file(worktree_id, delete_path).log_err();
-                    }
-
-                    anyhow::Ok(())
-                }
-                .log_err(),
-            );
-
             cx.background()
                 .scoped(|scope| {
                     for _ in 0..cx.background().num_cpus() {
@@ -348,8 +434,11 @@ impl VectorStore {
                                 .await
                                 .log_err()
                                 {
-                                    indexed_files_tx
-                                        .try_send((worktree_id, indexed_file))
+                                    db_update_tx
+                                        .try_send(DbWrite::InsertFile {
+                                            worktree_id,
+                                            indexed_file,
+                                        })
                                         .unwrap();
                                 }
                             }
@@ -357,12 +446,22 @@ impl VectorStore {
                     }
                 })
                 .await;
-            drop(indexed_files_tx);
 
-            db_update_task.await;
-
-            this.update(&mut cx, |this, _| {
-                this.worktree_db_ids.extend(worktree_db_ids);
+            this.update(&mut cx, |this, cx| {
+                let _subscription = cx.subscribe(&project, |this, project, event, cx| {
+                    if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                        //
+                        log::info!("worktree changes {:?}", changes);
+                    }
+                });
+
+                this.projects.insert(
+                    project.downgrade(),
+                    ProjectState {
+                        worktree_db_ids: db_ids_by_worktree_id.into_iter().collect(),
+                        _subscription,
+                    },
+                );
             });
 
             log::info!("Semantic Indexing Complete!");
@@ -373,23 +472,32 @@ impl VectorStore {
 
     pub fn search(
         &mut self,
-        project: &ModelHandle<Project>,
+        project: ModelHandle<Project>,
         phrase: String,
         limit: usize,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
-        let project = project.read(cx);
+        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
+            state
+        } else {
+            return Task::ready(Err(anyhow!("project not added")));
+        };
+
         let worktree_db_ids = project
+            .read(cx)
             .worktrees(cx)
             .filter_map(|worktree| {
                 let worktree_id = worktree.read(cx).id();
-                self.worktree_db_ids.iter().find_map(|(id, db_id)| {
-                    if *id == worktree_id {
-                        Some(*db_id)
-                    } else {
-                        None
-                    }
-                })
+                project_state
+                    .worktree_db_ids
+                    .iter()
+                    .find_map(|(id, db_id)| {
+                        if *id == worktree_id {
+                            Some(*db_id)
+                        } else {
+                            None
+                        }
+                    })
             })
             .collect::<Vec<_>>();
 
@@ -428,17 +536,27 @@ impl VectorStore {
                 })
                 .await?;
 
-            let results = this.read_with(&cx, |this, _| {
-                documents
+            this.read_with(&cx, |this, _| {
+                let project_state = if let Some(state) = this.projects.get(&project.downgrade()) {
+                    state
+                } else {
+                    return Err(anyhow!("project not added"));
+                };
+
+                Ok(documents
                     .into_iter()
                     .filter_map(|(worktree_db_id, file_path, offset, name)| {
-                        let worktree_id = this.worktree_db_ids.iter().find_map(|(id, db_id)| {
-                            if *db_id == worktree_db_id {
-                                Some(*id)
-                            } else {
-                                None
-                            }
-                        })?;
+                        let worktree_id =
+                            project_state
+                                .worktree_db_ids
+                                .iter()
+                                .find_map(|(id, db_id)| {
+                                    if *db_id == worktree_db_id {
+                                        Some(*id)
+                                    } else {
+                                        None
+                                    }
+                                })?;
                         Some(SearchResult {
                             worktree_id,
                             name,
@@ -446,10 +564,8 @@ impl VectorStore {
                             file_path,
                         })
                     })
-                    .collect()
-            });
-
-            anyhow::Ok(results)
+                    .collect())
+            })
         })
     }
 }

crates/vector_store/src/vector_store_tests.rs 🔗

@@ -5,7 +5,7 @@ use anyhow::Result;
 use async_trait::async_trait;
 use gpui::{Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry};
-use project::{FakeFs, Project};
+use project::{FakeFs, Fs, Project};
 use rand::Rng;
 use serde_json::json;
 use unindent::Unindent;
@@ -60,14 +60,15 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     let db_dir = tempdir::TempDir::new("vector-store").unwrap();
     let db_path = db_dir.path().join("db.sqlite");
 
-    let store = cx.add_model(|_| {
-        VectorStore::new(
-            fs.clone(),
-            db_path,
-            Arc::new(FakeEmbeddingProvider),
-            languages,
-        )
-    });
+    let store = VectorStore::new(
+        fs.clone(),
+        db_path,
+        Arc::new(FakeEmbeddingProvider),
+        languages,
+        cx.to_async(),
+    )
+    .await
+    .unwrap();
 
     let project = Project::test(fs, ["/the-root".as_ref()], cx).await;
     let worktree_id = project.read_with(cx, |project, cx| {
@@ -75,15 +76,11 @@ async fn test_vector_store(cx: &mut TestAppContext) {
     });
     let add_project = store.update(cx, |store, cx| store.add_project(project.clone(), cx));
 
-    // TODO - remove
-    cx.foreground()
-        .advance_clock(std::time::Duration::from_secs(3));
-
     add_project.await.unwrap();
 
     let search_results = store
         .update(cx, |store, cx| {
-            store.search(&project, "aaaa".to_string(), 5, cx)
+            store.search(project.clone(), "aaaa".to_string(), 5, cx)
         })
         .await
         .unwrap();