Ensure `SemanticIndex::search` waits for indexing to complete

Antonio Scandurra created

Change summary

crates/search/src/project_search.rs               |   4 
crates/semantic_index/src/semantic_index.rs       | 490 ++++++++++------
crates/semantic_index/src/semantic_index_tests.rs |  29 
3 files changed, 306 insertions(+), 217 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -635,7 +635,9 @@ impl ProjectSearchView {
             let project = self.model.read(cx).project.clone();
 
             let mut pending_file_count_rx = semantic_index.update(cx, |semantic_index, cx| {
-                semantic_index.index_project(project.clone(), cx);
+                semantic_index
+                    .index_project(project.clone(), cx)
+                    .detach_and_log_err(cx);
                 semantic_index.pending_file_count(&project).unwrap()
             });
 

crates/semantic_index/src/semantic_index.rs 🔗

@@ -13,7 +13,7 @@ use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
 use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{FutureExt, StreamExt};
+use futures::{future, FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
@@ -23,6 +23,7 @@ use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Work
 use smol::channel;
 use std::{
     cmp::Ordering,
+    future::Future,
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
@@ -32,7 +33,7 @@ use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
     http::HttpClient,
     paths::EMBEDDINGS_DIR,
-    ResultExt, TryFutureExt,
+    ResultExt,
 };
 
 const SEMANTIC_INDEX_VERSION: usize = 9;
@@ -132,7 +133,21 @@ impl WorktreeState {
 
 struct RegisteringWorktreeState {
     changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
-    _registration: Task<Option<()>>,
+    done_rx: watch::Receiver<Option<()>>,
+    _registration: Task<()>,
+}
+
+impl RegisteringWorktreeState {
+    fn done(&self) -> impl Future<Output = ()> {
+        let mut done_rx = self.done_rx.clone();
+        async move {
+            while let Some(result) = done_rx.next().await {
+                if result.is_some() {
+                    break;
+                }
+            }
+        }
+    }
 }
 
 struct RegisteredWorktreeState {
@@ -173,13 +188,6 @@ impl ProjectState {
         }
     }
 
-    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
-        match self.worktrees.get(&id)? {
-            WorktreeState::Registering(_) => None,
-            WorktreeState::Registered(state) => Some(state.db_id),
-        }
-    }
-
     fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
         self.worktrees
             .iter()
@@ -188,10 +196,6 @@ impl ProjectState {
                 _ => None,
             })
     }
-
-    fn worktree(&mut self, id: WorktreeId) -> Option<&mut WorktreeState> {
-        self.worktrees.get_mut(&id)
-    }
 }
 
 #[derive(Clone)]
@@ -390,17 +394,20 @@ impl SemanticIndex {
         };
 
         let worktree = worktree.read(cx);
-        let worktree_state = if let Some(worktree_state) = project_state.worktree(worktree_id) {
-            worktree_state
-        } else {
-            return;
-        };
+        let worktree_state =
+            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
+                worktree_state
+            } else {
+                return;
+            };
         worktree_state.paths_changed(changes, worktree);
         if let WorktreeState::Registered(_) = worktree_state {
             cx.spawn_weak(|this, mut cx| async move {
                 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
                 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
-                    this.update(&mut cx, |this, cx| this.index_project(project, cx));
+                    this.update(&mut cx, |this, cx| {
+                        this.index_project(project, cx).detach_and_log_err(cx)
+                    });
                 }
             })
             .detach();
@@ -429,109 +436,126 @@ impl SemanticIndex {
         let worktree_id = worktree.id();
         let db = self.db.clone();
         let language_registry = self.language_registry.clone();
+        let (mut done_tx, done_rx) = watch::channel();
         let registration = cx.spawn(|this, mut cx| {
             async move {
-                scan_complete.await;
-                let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
-                let mut file_mtimes = db.get_file_mtimes(db_id).await?;
-                let worktree = if let Some(project) = project.upgrade(&cx) {
-                    project
-                        .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
-                        .ok_or_else(|| anyhow!("worktree not found"))?
-                } else {
-                    return anyhow::Ok(());
-                };
-                let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
-                let mut changed_paths = cx
-                    .background()
-                    .spawn(async move {
-                        let mut changed_paths = BTreeMap::new();
-                        for file in worktree.files(false, 0) {
-                            let absolute_path = worktree.absolutize(&file.path);
-
-                            if file.is_external || file.is_ignored || file.is_symlink {
-                                continue;
-                            }
-
-                            if let Ok(language) = language_registry
-                                .language_for_file(&absolute_path, None)
-                                .await
-                            {
-                                // Test if file is valid parseable file
-                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                                    && &language.name().as_ref() != &"Markdown"
-                                    && language
-                                        .grammar()
-                                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                                        .is_none()
-                                {
+                let register = async {
+                    scan_complete.await;
+                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
+                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
+                    let worktree = if let Some(project) = project.upgrade(&cx) {
+                        project
+                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
+                            .ok_or_else(|| anyhow!("worktree not found"))?
+                    } else {
+                        return anyhow::Ok(());
+                    };
+                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
+                    let mut changed_paths = cx
+                        .background()
+                        .spawn(async move {
+                            let mut changed_paths = BTreeMap::new();
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+
+                                if file.is_external || file.is_ignored || file.is_symlink {
                                     continue;
                                 }
 
-                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
-                                let already_stored = stored_mtime
-                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
-
-                                if !already_stored {
-                                    changed_paths.insert(
-                                        file.path.clone(),
-                                        ChangedPathInfo {
-                                            mtime: file.mtime,
-                                            is_deleted: false,
-                                        },
-                                    );
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    // Test if file is valid parseable file
+                                    if !PARSEABLE_ENTIRE_FILE_TYPES
+                                        .contains(&language.name().as_ref())
+                                        && &language.name().as_ref() != &"Markdown"
+                                        && language
+                                            .grammar()
+                                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                                            .is_none()
+                                    {
+                                        continue;
+                                    }
+
+                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+                                    let already_stored = stored_mtime
+                                        .map_or(false, |existing_mtime| {
+                                            existing_mtime == file.mtime
+                                        });
+
+                                    if !already_stored {
+                                        changed_paths.insert(
+                                            file.path.clone(),
+                                            ChangedPathInfo {
+                                                mtime: file.mtime,
+                                                is_deleted: false,
+                                            },
+                                        );
+                                    }
                                 }
                             }
-                        }
 
-                        // Clean up entries from database that are no longer in the worktree.
-                        for (path, mtime) in file_mtimes {
-                            changed_paths.insert(
-                                path.into(),
-                                ChangedPathInfo {
-                                    mtime,
-                                    is_deleted: true,
-                                },
-                            );
-                        }
+                            // Clean up entries from database that are no longer in the worktree.
+                            for (path, mtime) in file_mtimes {
+                                changed_paths.insert(
+                                    path.into(),
+                                    ChangedPathInfo {
+                                        mtime,
+                                        is_deleted: true,
+                                    },
+                                );
+                            }
 
-                        anyhow::Ok(changed_paths)
-                    })
-                    .await?;
-                this.update(&mut cx, |this, cx| {
-                    let project_state = this
-                        .projects
-                        .get_mut(&project)
-                        .ok_or_else(|| anyhow!("project not registered"))?;
-                    let project = project
-                        .upgrade(cx)
-                        .ok_or_else(|| anyhow!("project was dropped"))?;
-
-                    if let Some(WorktreeState::Registering(state)) =
-                        project_state.worktrees.remove(&worktree_id)
-                    {
-                        changed_paths.extend(state.changed_paths);
-                    }
-                    project_state.worktrees.insert(
-                        worktree_id,
-                        WorktreeState::Registered(RegisteredWorktreeState {
-                            db_id,
-                            changed_paths,
-                        }),
-                    );
-                    this.index_project(project, cx);
+                            anyhow::Ok(changed_paths)
+                        })
+                        .await?;
+                    this.update(&mut cx, |this, cx| {
+                        let project_state = this
+                            .projects
+                            .get_mut(&project)
+                            .ok_or_else(|| anyhow!("project not registered"))?;
+                        let project = project
+                            .upgrade(cx)
+                            .ok_or_else(|| anyhow!("project was dropped"))?;
+
+                        if let Some(WorktreeState::Registering(state)) =
+                            project_state.worktrees.remove(&worktree_id)
+                        {
+                            changed_paths.extend(state.changed_paths);
+                        }
+                        project_state.worktrees.insert(
+                            worktree_id,
+                            WorktreeState::Registered(RegisteredWorktreeState {
+                                db_id,
+                                changed_paths,
+                            }),
+                        );
+                        this.index_project(project, cx).detach_and_log_err(cx);
+
+                        anyhow::Ok(())
+                    })?;
 
                     anyhow::Ok(())
-                })?;
+                };
 
-                anyhow::Ok(())
+                if register.await.log_err().is_none() {
+                    // Stop tracking this worktree if the registration failed.
+                    this.update(&mut cx, |this, _| {
+                        this.projects.get_mut(&project).map(|project_state| {
+                            project_state.worktrees.remove(&worktree_id);
+                        });
+                    })
+                }
+
+                *done_tx.borrow_mut() = Some(());
             }
-            .log_err()
         });
         project_state.worktrees.insert(
             worktree_id,
             WorktreeState::Registering(RegisteringWorktreeState {
                 changed_paths: Default::default(),
+                done_rx,
                 _registration: registration,
             }),
         );
@@ -567,7 +591,7 @@ impl SemanticIndex {
         // Register new worktrees
         worktrees.retain(|worktree| {
             let worktree_id = worktree.read(cx).id();
-            project_state.worktree(worktree_id).is_none()
+            !project_state.worktrees.contains_key(&worktree_id)
         });
         for worktree in worktrees {
             self.register_worktree(project.clone(), worktree, cx);
@@ -595,25 +619,13 @@ impl SemanticIndex {
         excludes: Vec<PathMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
-        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
-            state
-        } else {
-            return Task::ready(Err(anyhow!("project not added")));
-        };
-
-        let worktree_db_ids = project
-            .read(cx)
-            .worktrees(cx)
-            .filter_map(|worktree| {
-                let worktree_id = worktree.read(cx).id();
-                project_state.db_id_for_worktree_id(worktree_id)
-            })
-            .collect::<Vec<_>>();
-
+        let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
         let db_path = self.db.path().clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
+            index.await?;
+
             let t0 = Instant::now();
             let database =
                 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
@@ -630,6 +642,24 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );
 
+            let worktree_db_ids = this.read_with(&cx, |this, _| {
+                let project_state = this
+                    .projects
+                    .get(&project.downgrade())
+                    .ok_or_else(|| anyhow!("project was not indexed"))?;
+                let worktree_db_ids = project_state
+                    .worktrees
+                    .values()
+                    .filter_map(|worktree| {
+                        if let WorktreeState::Registered(worktree) = worktree {
+                            Some(worktree.db_id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<i64>>();
+                anyhow::Ok(worktree_db_ids)
+            })?;
             let file_ids = database
                 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                 .await?;
@@ -723,7 +753,11 @@ impl SemanticIndex {
         })
     }
 
-    pub fn index_project(&mut self, project: ModelHandle<Project>, cx: &mut ModelContext<Self>) {
+    pub fn index_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
         if !self.projects.contains_key(&project.downgrade()) {
             log::trace!("Registering Project for Semantic Index");
 
@@ -740,96 +774,152 @@ impl SemanticIndex {
                 .insert(project.downgrade(), ProjectState::new(subscription));
             self.project_worktrees_changed(project.clone(), cx);
         }
+        let project_state = self.projects.get(&project.downgrade()).unwrap();
+        let mut outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone();
 
-        let project_state = self.projects.get_mut(&project.downgrade()).unwrap();
+        let db = self.db.clone();
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
 
-        let mut pending_files = Vec::new();
-        let mut files_to_delete = Vec::new();
-        let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
-        project_state
-            .worktrees
-            .retain(|worktree_id, worktree_state| {
-                let worktree =
-                    if let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx) {
-                        worktree
-                    } else {
-                        return false;
-                    };
-                let worktree_state =
-                    if let WorktreeState::Registered(worktree_state) = worktree_state {
-                        worktree_state
-                    } else {
-                        return true;
-                    };
+        cx.spawn(|this, mut cx| async move {
+            worktree_registration.await?;
+
+            let mut pending_files = Vec::new();
+            let mut files_to_delete = Vec::new();
+            this.update(&mut cx, |this, cx| {
+                let project_state = this
+                    .projects
+                    .get_mut(&project.downgrade())
+                    .ok_or_else(|| anyhow!("project was dropped"))?;
+                let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
+
+                project_state
+                    .worktrees
+                    .retain(|worktree_id, worktree_state| {
+                        let worktree = if let Some(worktree) =
+                            project.read(cx).worktree_for_id(*worktree_id, cx)
+                        {
+                            worktree
+                        } else {
+                            return false;
+                        };
+                        let worktree_state =
+                            if let WorktreeState::Registered(worktree_state) = worktree_state {
+                                worktree_state
+                            } else {
+                                return true;
+                            };
+
+                        worktree_state.changed_paths.retain(|path, info| {
+                            if info.is_deleted {
+                                files_to_delete.push((worktree_state.db_id, path.clone()));
+                            } else {
+                                let absolute_path = worktree.read(cx).absolutize(path);
+                                let job_handle = JobHandle::new(outstanding_job_count_tx);
+                                pending_files.push(PendingFile {
+                                    absolute_path,
+                                    relative_path: path.clone(),
+                                    language: None,
+                                    job_handle,
+                                    modified_time: info.mtime,
+                                    worktree_db_id: worktree_state.db_id,
+                                });
+                            }
 
-                worktree_state.changed_paths.retain(|path, info| {
-                    if info.is_deleted {
-                        files_to_delete.push((worktree_state.db_id, path.clone()));
-                    } else {
-                        let absolute_path = worktree.read(cx).absolutize(path);
-                        let job_handle = JobHandle::new(&outstanding_job_count_tx);
-                        pending_files.push(PendingFile {
-                            absolute_path,
-                            relative_path: path.clone(),
-                            language: None,
-                            job_handle,
-                            modified_time: info.mtime,
-                            worktree_db_id: worktree_state.db_id,
+                            false
                         });
-                    }
-
-                    false
-                });
-                true
-            });
+                        true
+                    });
 
-        let db = self.db.clone();
-        let language_registry = self.language_registry.clone();
-        let parsing_files_tx = self.parsing_files_tx.clone();
-        cx.background()
-            .spawn(async move {
-                for (worktree_db_id, path) in files_to_delete {
-                    db.delete_file(worktree_db_id, path).await.log_err();
-                }
+                anyhow::Ok(())
+            })?;
 
-                let embeddings_for_digest = {
-                    let mut files = HashMap::default();
-                    for pending_file in &pending_files {
-                        files
-                            .entry(pending_file.worktree_db_id)
-                            .or_insert(Vec::new())
-                            .push(pending_file.relative_path.clone());
+            cx.background()
+                .spawn(async move {
+                    for (worktree_db_id, path) in files_to_delete {
+                        db.delete_file(worktree_db_id, path).await.log_err();
                     }
-                    Arc::new(
-                        db.embeddings_for_files(files)
-                            .await
-                            .log_err()
-                            .unwrap_or_default(),
-                    )
-                };
 
-                for mut pending_file in pending_files {
-                    if let Ok(language) = language_registry
-                        .language_for_file(&pending_file.relative_path, None)
-                        .await
-                    {
-                        if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                            && &language.name().as_ref() != &"Markdown"
-                            && language
-                                .grammar()
-                                .and_then(|grammar| grammar.embedding_config.as_ref())
-                                .is_none()
+                    let embeddings_for_digest = {
+                        let mut files = HashMap::default();
+                        for pending_file in &pending_files {
+                            files
+                                .entry(pending_file.worktree_db_id)
+                                .or_insert(Vec::new())
+                                .push(pending_file.relative_path.clone());
+                        }
+                        Arc::new(
+                            db.embeddings_for_files(files)
+                                .await
+                                .log_err()
+                                .unwrap_or_default(),
+                        )
+                    };
+
+                    for mut pending_file in pending_files {
+                        if let Ok(language) = language_registry
+                            .language_for_file(&pending_file.relative_path, None)
+                            .await
                         {
-                            continue;
+                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
+                                && &language.name().as_ref() != &"Markdown"
+                                && language
+                                    .grammar()
+                                    .and_then(|grammar| grammar.embedding_config.as_ref())
+                                    .is_none()
+                            {
+                                continue;
+                            }
+                            pending_file.language = Some(language);
                         }
-                        pending_file.language = Some(language);
+                        parsing_files_tx
+                            .try_send((embeddings_for_digest.clone(), pending_file))
+                            .ok();
                     }
-                    parsing_files_tx
-                        .try_send((embeddings_for_digest.clone(), pending_file))
-                        .ok();
+
+                    // Wait until we're done indexing.
+                    while let Some(count) = outstanding_job_count_rx.next().await {
+                        if count == 0 {
+                            break;
+                        }
+                    }
+                })
+                .await;
+
+            Ok(())
+        })
+    }
+
+    fn wait_for_worktree_registration(
+        &self,
+        project: &ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let project = project.downgrade();
+        cx.spawn_weak(|this, cx| async move {
+            loop {
+                let mut pending_worktrees = Vec::new();
+                this.upgrade(&cx)
+                    .ok_or_else(|| anyhow!("semantic index dropped"))?
+                    .read_with(&cx, |this, _| {
+                        if let Some(project) = this.projects.get(&project) {
+                            for worktree in project.worktrees.values() {
+                                if let WorktreeState::Registering(worktree) = worktree {
+                                    pending_worktrees.push(worktree.done());
+                                }
+                            }
+                        }
+                    });
+
+                if pending_worktrees.is_empty() {
+                    break;
+                } else {
+                    future::join_all(pending_worktrees).await;
                 }
-            })
-            .detach()
+            }
+            Ok(())
+        })
     }
 }
 

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -87,7 +87,16 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 
-    semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
+    let search_results = semantic_index.update(cx, |store, cx| {
+        store.search_project(
+            project.clone(),
+            "aaaaaabbbbzz".to_string(),
+            5,
+            vec![],
+            vec![],
+            cx,
+        )
+    });
     let pending_file_count =
         semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
     deterministic.run_until_parked();
@@ -95,20 +104,7 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
     assert_eq!(*pending_file_count.borrow(), 0);
 
-    let search_results = semantic_index
-        .update(cx, |store, cx| {
-            store.search_project(
-                project.clone(),
-                "aaaaaabbbbzz".to_string(),
-                5,
-                vec![],
-                vec![],
-                cx,
-            )
-        })
-        .await
-        .unwrap();
-
+    let search_results = search_results.await.unwrap();
     assert_search_results(
         &search_results,
         &[
@@ -185,11 +181,12 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
+    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
     deterministic.run_until_parked();
     assert_eq!(*pending_file_count.borrow(), 1);
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
     assert_eq!(*pending_file_count.borrow(), 0);
+    index.await.unwrap();
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,