Move retrieval of embeddings from the db into `reindex_changed_paths`

Antonio Scandurra and Kyle Caverly created

Co-Authored-By: Kyle Caverly <kyle@zed.dev>

Change summary

crates/semantic_index/src/semantic_index.rs | 220 ++++++++++------------
1 file changed, 100 insertions(+), 120 deletions(-)

Detailed changes

crates/semantic_index/src/semantic_index.rs 🔗

@@ -418,30 +418,12 @@ impl SemanticIndex {
         };
         worktree_state.paths_changed(changes, Instant::now(), worktree);
         if let WorktreeState::Registered(worktree_state) = worktree_state {
-            let embeddings_for_digest = {
-                let worktree_paths = worktree_state
-                    .changed_paths
-                    .iter()
-                    .map(|(path, _)| path.clone())
-                    .collect::<Vec<_>>();
-                let mut worktree_id_file_paths = HashMap::default();
-                worktree_id_file_paths.insert(worktree_state.db_id, worktree_paths);
-                self.db.embeddings_for_files(worktree_id_file_paths)
-            };
-
             cx.spawn_weak(|this, mut cx| async move {
-                let embeddings_for_digest =
-                    embeddings_for_digest.await.log_err().unwrap_or_default();
-
                 cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
                 if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
-                    Self::reindex_changed_paths(
-                        this,
-                        project,
-                        Some(change_time),
-                        &mut cx,
-                        Arc::new(embeddings_for_digest),
-                    )
+                    this.update(&mut cx, |this, cx| {
+                        this.reindex_changed_paths(project, Some(change_time), cx)
+                    })
                     .await;
                 }
             })
@@ -644,31 +626,10 @@ impl SemanticIndex {
             return Task::ready(Err(anyhow!("project was not registered")));
         };
         let outstanding_job_count_rx = project_state.outstanding_job_count_rx.clone();
-
-        let mut worktree_id_file_paths = HashMap::default();
-        for worktree in project_state.worktrees.values() {
-            if let WorktreeState::Registered(worktree_state) = worktree {
-                for (path, _) in &worktree_state.changed_paths {
-                    worktree_id_file_paths
-                        .entry(worktree_state.db_id)
-                        .or_insert(Vec::new())
-                        .push(path.clone());
-                }
-            }
-        }
-
         cx.spawn(|this, mut cx| async move {
-            let embeddings_for_digest = this.read_with(&cx, |this, _| {
-                this.db.embeddings_for_files(worktree_id_file_paths)
-            });
-            let embeddings_for_digest = Arc::new(embeddings_for_digest.await?);
-            Self::reindex_changed_paths(
-                this.clone(),
-                project.clone(),
-                None,
-                &mut cx,
-                embeddings_for_digest,
-            )
+            this.update(&mut cx, |this, cx| {
+                this.reindex_changed_paths(project.clone(), None, cx)
+            })
             .await;
             let count = *outstanding_job_count_rx.borrow();
             Ok((count, outstanding_job_count_rx))
@@ -822,94 +783,113 @@ impl SemanticIndex {
         })
     }
 
-    async fn reindex_changed_paths(
-        this: ModelHandle<SemanticIndex>,
+    fn reindex_changed_paths(
+        &mut self,
         project: ModelHandle<Project>,
         last_changed_before: Option<Instant>,
-        cx: &mut AsyncAppContext,
-        embeddings_for_digest: Arc<HashMap<DocumentDigest, Embedding>>,
-    ) {
+        cx: &mut ModelContext<Self>,
+    ) -> Task<()> {
+        let project_state = if let Some(project_state) = self.projects.get_mut(&project.downgrade())
+        {
+            project_state
+        } else {
+            return Task::ready(());
+        };
+
         let mut pending_files = Vec::new();
         let mut files_to_delete = Vec::new();
-        let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
-            if let Some(project_state) = this.projects.get_mut(&project.downgrade()) {
-                let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
-                project_state
-                    .worktrees
-                    .retain(|worktree_id, worktree_state| {
-                        let worktree = if let Some(worktree) =
-                            project.read(cx).worktree_for_id(*worktree_id, cx)
-                        {
-                            worktree
-                        } else {
-                            return false;
-                        };
-                        let worktree_state =
-                            if let WorktreeState::Registered(worktree_state) = worktree_state {
-                                worktree_state
-                            } else {
-                                return true;
-                            };
-
-                        worktree_state.changed_paths.retain(|path, info| {
-                            if let Some(last_changed_before) = last_changed_before {
-                                if info.changed_at > last_changed_before {
-                                    return true;
-                                }
-                            }
-
-                            if info.is_deleted {
-                                files_to_delete.push((worktree_state.db_id, path.clone()));
-                            } else {
-                                let absolute_path = worktree.read(cx).absolutize(path);
-                                let job_handle = JobHandle::new(&outstanding_job_count_tx);
-                                pending_files.push(PendingFile {
-                                    absolute_path,
-                                    relative_path: path.clone(),
-                                    language: None,
-                                    job_handle,
-                                    modified_time: info.mtime,
-                                    worktree_db_id: worktree_state.db_id,
-                                });
-                            }
+        let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
+        project_state
+            .worktrees
+            .retain(|worktree_id, worktree_state| {
+                let worktree =
+                    if let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx) {
+                        worktree
+                    } else {
+                        return false;
+                    };
+                let worktree_state =
+                    if let WorktreeState::Registered(worktree_state) = worktree_state {
+                        worktree_state
+                    } else {
+                        return true;
+                    };
+
+                worktree_state.changed_paths.retain(|path, info| {
+                    if let Some(last_changed_before) = last_changed_before {
+                        if info.changed_at > last_changed_before {
+                            return true;
+                        }
+                    }
 
-                            false
+                    if info.is_deleted {
+                        files_to_delete.push((worktree_state.db_id, path.clone()));
+                    } else {
+                        let absolute_path = worktree.read(cx).absolutize(path);
+                        let job_handle = JobHandle::new(&outstanding_job_count_tx);
+                        pending_files.push(PendingFile {
+                            absolute_path,
+                            relative_path: path.clone(),
+                            language: None,
+                            job_handle,
+                            modified_time: info.mtime,
+                            worktree_db_id: worktree_state.db_id,
                         });
-                        true
-                    });
-            }
+                    }
 
-            (
-                this.db.clone(),
-                this.language_registry.clone(),
-                this.parsing_files_tx.clone(),
-            )
-        });
+                    false
+                });
+                true
+            });
 
-        for (worktree_db_id, path) in files_to_delete {
-            db.delete_file(worktree_db_id, path).await.log_err();
+        let mut worktree_id_file_paths = HashMap::default();
+        for worktree in project_state.worktrees.values() {
+            if let WorktreeState::Registered(worktree_state) = worktree {
+                for (path, _) in &worktree_state.changed_paths {
+                    worktree_id_file_paths
+                        .entry(worktree_state.db_id)
+                        .or_insert(Vec::new())
+                        .push(path.clone());
+                }
+            }
         }
 
-        for mut pending_file in pending_files {
-            if let Ok(language) = language_registry
-                .language_for_file(&pending_file.relative_path, None)
-                .await
-            {
-                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                    && &language.name().as_ref() != &"Markdown"
-                    && language
-                        .grammar()
-                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                        .is_none()
+        let db = self.db.clone();
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+        cx.background().spawn(async move {
+            for (worktree_db_id, path) in files_to_delete {
+                db.delete_file(worktree_db_id, path).await.log_err();
+            }
+
+            let embeddings_for_digest = Arc::new(
+                db.embeddings_for_files(worktree_id_file_paths)
+                    .await
+                    .log_err()
+                    .unwrap_or_default(),
+            );
+
+            for mut pending_file in pending_files {
+                if let Ok(language) = language_registry
+                    .language_for_file(&pending_file.relative_path, None)
+                    .await
                 {
-                    continue;
+                    if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
+                        && &language.name().as_ref() != &"Markdown"
+                        && language
+                            .grammar()
+                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                            .is_none()
+                    {
+                        continue;
+                    }
+                    pending_file.language = Some(language);
                 }
-                pending_file.language = Some(language);
+                parsing_files_tx
+                    .try_send((embeddings_for_digest.clone(), pending_file))
+                    .ok();
             }
-            parsing_files_tx
-                .try_send((embeddings_for_digest.clone(), pending_file))
-                .ok();
-        }
+        })
     }
 }