reworked ProjectState to include additional context

KCaverly created

Change summary

crates/search/src/project_search.rs               |  12 +
crates/semantic_index/src/embedding.rs            |   2 
crates/semantic_index/src/semantic_index.rs       | 138 +++++++++++++++-
crates/semantic_index/src/semantic_index_tests.rs |   7 
4 files changed, 146 insertions(+), 13 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -640,6 +640,7 @@ impl ProjectSearchView {
             self.search_options = SearchOptions::none();
 
             let project = self.model.read(cx).project.clone();
+
             let index_task = semantic_index.update(cx, |semantic_index, cx| {
                 semantic_index.index_project(project, cx)
             });
@@ -759,7 +760,7 @@ impl ProjectSearchView {
     }
 
     fn new(model: ModelHandle<ProjectSearch>, cx: &mut ViewContext<Self>) -> Self {
-        let project;
+        let mut project;
         let excerpts;
         let mut query_text = String::new();
         let mut options = SearchOptions::NONE;
@@ -843,6 +844,15 @@ impl ProjectSearchView {
         .detach();
         let filters_enabled = false;
 
+        // Initialize the semantic index for this project if the feature is enabled.
+        if SemanticIndex::enabled(cx) {
+            project = model.read(cx).project.clone();
+            if let Some(semantic) = SemanticIndex::global(cx) {
+                semantic.update(cx, |this, cx| this.initialize_project(project, cx));
+            }
+        }
+
+
         // Check if Worktrees have all been previously indexed
         let mut this = ProjectSearchView {
             search_id: model.read(cx).search_id,

crates/semantic_index/src/embedding.rs 🔗

@@ -39,7 +39,7 @@ struct OpenAIEmbeddingResponse {
 
 #[derive(Debug, Deserialize)]
 struct OpenAIEmbedding {
-    embedding: Vec<f16>,
+    embedding: Vec<f32>,
     index: usize,
     object: String,
 }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -92,7 +92,8 @@ pub struct SemanticIndex {
 
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
-    file_mtimes: HashMap<PathBuf, SystemTime>,
+    worktree_file_mtimes: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>>,
+    subscription: gpui::Subscription,
     outstanding_job_count_rx: watch::Receiver<usize>,
     _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
@@ -113,6 +114,25 @@ impl JobHandle {
     }
 }
 impl ProjectState {
+    fn new(
+        subscription: gpui::Subscription,
+        worktree_db_ids: Vec<(WorktreeId, i64)>,
+        worktree_file_mtimes: HashMap<WorktreeId, HashMap<PathBuf, SystemTime>>,
+        outstanding_job_count_rx: watch::Receiver<usize>,
+        _outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
+    ) -> Self {
+        Self {
+            worktree_db_ids,
+            worktree_file_mtimes,
+            outstanding_job_count_rx,
+            _outstanding_job_count_tx,
+            subscription,
+        }
+    }
+
     fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
         self.worktree_db_ids
             .iter()
@@ -577,6 +597,84 @@ impl SemanticIndex {
         })
     }
 
+    /// Registers `project` with the semantic index: waits for worktree scans to
+    /// finish, resolves worktree database ids and stored file mtimes, then
+    /// inserts the resulting `ProjectState`. Returns a task that resolves once
+    /// registration is complete, so callers (e.g. tests) can await it before
+    /// calling `index_project`.
+    pub fn initialize_project(
+        &mut self,
+        project: ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let worktree_scans_complete = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| {
+                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
+                async move {
+                    scan_complete.await;
+                }
+            })
+            .collect::<Vec<_>>();
+
+        let worktree_db_ids = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| {
+                self.find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
+            })
+            .collect::<Vec<_>>();
+
+        // NOTE(review): this handler panics via todo!() as soon as any worktree
+        // entry changes; project_entries_changed must be implemented before
+        // this can ship.
+        let _subscription = cx.subscribe(&project, |this, project, event, cx| {
+            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                todo!();
+                // this.project_entries_changed(project, changes, cx, worktree_id);
+            }
+        });
+
+        cx.spawn(|this, mut cx| async move {
+            futures::future::join_all(worktree_scans_complete).await;
+
+            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
+            let worktrees = project.read_with(&cx, |project, cx| {
+                project
+                    .worktrees(cx)
+                    .map(|worktree| worktree.read(cx).snapshot())
+                    .collect::<Vec<_>>()
+            });
+
+            let mut worktree_file_mtimes = HashMap::new();
+            let mut db_ids_by_worktree_id = HashMap::new();
+
+            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
+                let db_id = db_id?;
+                db_ids_by_worktree_id.insert(worktree.id(), db_id);
+                worktree_file_mtimes.insert(
+                    worktree.id(),
+                    this.read_with(&cx, |this, _| this.get_file_mtimes(db_id))
+                        .await?,
+                );
+            }
+
+            let worktree_db_ids = db_ids_by_worktree_id
+                .iter()
+                .map(|(a, b)| (*a, *b))
+                .collect();
+
+            let (job_count_tx, job_count_rx) = watch::channel_with(0);
+            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
+            this.update(&mut cx, |this, _| {
+                let project_state = ProjectState::new(
+                    _subscription,
+                    worktree_db_ids,
+                    worktree_file_mtimes,
+                    job_count_rx,
+                    job_count_tx,
+                );
+                this.projects.insert(project.downgrade(), project_state);
+            });
+
+            anyhow::Ok(())
+        })
+    }
+
     pub fn index_project(
         &mut self,
         project: ModelHandle<Project>,
@@ -605,6 +703,22 @@ impl SemanticIndex {
         let db_update_tx = self.db_update_tx.clone();
         let parsing_files_tx = self.parsing_files_tx.clone();
 
+        let Some(state) = self.projects.get(&project.downgrade()) else {
+            return Task::ready(Err(anyhow!("Project not yet initialized")));
+        };
+        let state = state.clone();
+
+
+        // NOTE(review): this subscription duplicates the one already created in
+        // initialize_project, and its todo!() will panic on any worktree
+        // update event; the handler needs a real implementation before shipping.
+        let _subscription = cx.subscribe(&project, |this, project, event, _cx| {
+            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
+                todo!();
+                // this.project_entries_changed(project, changes, cx, worktree_id);
+            }
+        });
+
+
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
 
@@ -629,20 +743,22 @@ impl SemanticIndex {
                 );
             }
 
+            let worktree_db_ids = db_ids_by_worktree_id
+                .iter()
+                .map(|(a, b)| (*a, *b))
+                .collect();
+
             let (job_count_tx, job_count_rx) = watch::channel_with(0);
             let job_count_tx = Arc::new(Mutex::new(job_count_tx));
             this.update(&mut cx, |this, _| {
-                this.projects.insert(
-                    project.downgrade(),
-                    ProjectState {
-                        worktree_db_ids: db_ids_by_worktree_id
-                            .iter()
-                            .map(|(a, b)| (*a, *b))
-                            .collect(),
-                        outstanding_job_count_rx: job_count_rx.clone(),
-                        _outstanding_job_count_tx: job_count_tx.clone(),
-                    },
+                let project_state = ProjectState::new(
+                    _subscription,
+                    worktree_db_ids,
+                    worktree_file_mtimes.clone(),
+                    job_count_rx.clone(),
+                    job_count_tx.clone(),
                 );
+                this.projects.insert(project.downgrade(), project_state);
             });
 
             cx.background()

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -86,6 +86,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     .unwrap();
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
+
+    store
+        .update(cx, |store, cx| {
+            store.initialize_project(project.clone(), cx)
+        })
+        .await;
+
     let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await