Start work on exposing semantic search via project search view

Created by Max Brunsfeld and Kyle

Co-authored-by: Kyle <kyle@zed.dev>

Change summary

Cargo.lock                                        |   2 
crates/search/Cargo.toml                          |   1 
crates/search/src/project_search.rs               | 156 +++++
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/db.rs                   |  12 
crates/semantic_index/src/embedding.rs            |   7 
crates/semantic_index/src/modal.rs                | 172 ------
crates/semantic_index/src/semantic_index.rs       | 451 ++++++++--------
crates/semantic_index/src/semantic_index_tests.rs |  18 
9 files changed, 397 insertions(+), 423 deletions(-)

Detailed changes

Cargo.lock

@@ -6430,6 +6430,7 @@ dependencies = [
  "menu",
  "postage",
  "project",
+ "semantic_index",
  "serde",
  "serde_derive",
  "serde_json",
@@ -6484,6 +6485,7 @@ dependencies = [
  "matrixmultiply",
  "parking_lot 0.11.2",
  "picker",
+ "postage",
  "project",
  "rand 0.8.5",
  "rpc",

crates/search/Cargo.toml

@@ -19,6 +19,7 @@ settings = { path = "../settings" }
 theme = { path = "../theme" }
 util = { path = "../util" }
 workspace = { path = "../workspace" }
+semantic_index = { path = "../semantic_index" }
 anyhow.workspace = true
 futures.workspace = true
 log.workspace = true

crates/search/src/project_search.rs

@@ -2,7 +2,7 @@ use crate::{
     SearchOption, SelectNextMatch, SelectPrevMatch, ToggleCaseSensitive, ToggleRegex,
     ToggleWholeWord,
 };
-use anyhow::Result;
+use anyhow::{Context, Result};
 use collections::HashMap;
 use editor::{
     items::active_match_index, scroll::autoscroll::Autoscroll, Anchor, Editor, MultiBuffer,
@@ -18,7 +18,9 @@ use gpui::{
     Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
 };
 use menu::Confirm;
+use postage::stream::Stream;
 use project::{search::SearchQuery, Project};
+use semantic_index::SemanticIndex;
 use smallvec::SmallVec;
 use std::{
     any::{Any, TypeId},
@@ -36,7 +38,10 @@ use workspace::{
     ItemNavHistory, Pane, ToolbarItemLocation, ToolbarItemView, Workspace, WorkspaceId,
 };
 
-actions!(project_search, [SearchInNew, ToggleFocus, NextField]);
+actions!(
+    project_search,
+    [SearchInNew, ToggleFocus, NextField, ToggleSemanticSearch]
+);
 
 #[derive(Default)]
 struct ActiveSearches(HashMap<WeakModelHandle<Project>, WeakViewHandle<ProjectSearchView>>);
@@ -92,6 +97,7 @@ pub struct ProjectSearchView {
     case_sensitive: bool,
     whole_word: bool,
     regex: bool,
+    semantic: Option<SemanticSearchState>,
     panels_with_errors: HashSet<InputPanel>,
     active_match_index: Option<usize>,
     search_id: usize,
@@ -100,6 +106,13 @@ pub struct ProjectSearchView {
     excluded_files_editor: ViewHandle<Editor>,
 }
 
+struct SemanticSearchState {
+    file_count: usize,
+    outstanding_file_count: usize,
+    _progress_task: Task<()>,
+    search_task: Option<Task<Result<()>>>,
+}
+
 pub struct ProjectSearchBar {
     active_project_search: Option<ViewHandle<ProjectSearchView>>,
     subscription: Option<Subscription>,
@@ -198,12 +211,25 @@ impl View for ProjectSearchView {
 
             let theme = theme::current(cx).clone();
             let text = if self.query_editor.read(cx).text(cx).is_empty() {
-                ""
+                Cow::Borrowed("")
+            } else if let Some(semantic) = &self.semantic {
+                if semantic.search_task.is_some() {
+                    Cow::Borrowed("Searching...")
+                } else if semantic.outstanding_file_count > 0 {
+                    Cow::Owned(format!(
+                        "Indexing. {} of {}...",
+                        semantic.file_count - semantic.outstanding_file_count,
+                        semantic.file_count
+                    ))
+                } else {
+                    Cow::Borrowed("Indexing complete")
+                }
             } else if model.pending_search.is_some() {
-                "Searching..."
+                Cow::Borrowed("Searching...")
             } else {
-                "No results"
+                Cow::Borrowed("No results")
             };
+
             MouseEventHandler::<Status, _>::new(0, cx, |_, _| {
                 Label::new(text, theme.search.results_status.clone())
                     .aligned()
@@ -499,6 +525,7 @@ impl ProjectSearchView {
             case_sensitive,
             whole_word,
             regex,
+            semantic: None,
             panels_with_errors: HashSet::new(),
             active_match_index: None,
             query_editor_was_focused: false,
@@ -563,6 +590,35 @@ impl ProjectSearchView {
     }
 
     fn search(&mut self, cx: &mut ViewContext<Self>) {
+        if let Some(semantic) = &mut self.semantic {
+            if semantic.outstanding_file_count > 0 {
+                return;
+            }
+
+            let search_phrase = self.query_editor.read(cx).text(cx);
+            let project = self.model.read(cx).project.clone();
+            if let Some(semantic_index) = SemanticIndex::global(cx) {
+                let search_task = semantic_index.update(cx, |semantic_index, cx| {
+                    semantic_index.search_project(project, search_phrase, 10, cx)
+                });
+                semantic.search_task = Some(cx.spawn(|this, mut cx| async move {
+                    let results = search_task.await.context("search task")?;
+
+                    this.update(&mut cx, |this, cx| {
+                        dbg!(&results);
+                        // TODO: Update results
+
+                        if let Some(semantic) = &mut this.semantic {
+                            semantic.search_task = None;
+                        }
+                    })?;
+
+                    anyhow::Ok(())
+                }));
+            }
+            return;
+        }
+
         if let Some(query) = self.build_search_query(cx) {
             self.model.update(cx, |model, cx| model.search(query, cx));
         }
@@ -876,6 +932,59 @@ impl ProjectSearchBar {
         }
     }
 
+    fn toggle_semantic_search(&mut self, cx: &mut ViewContext<Self>) -> bool {
+        if let Some(search_view) = self.active_project_search.as_ref() {
+            search_view.update(cx, |search_view, cx| {
+                if search_view.semantic.is_some() {
+                    search_view.semantic = None;
+                } else if let Some(semantic_index) = SemanticIndex::global(cx) {
+                    // TODO: confirm that it's ok to send this project
+
+                    let project = search_view.model.read(cx).project.clone();
+                    let index_task = semantic_index.update(cx, |semantic_index, cx| {
+                        semantic_index.index_project(project, cx)
+                    });
+
+                    cx.spawn(|search_view, mut cx| async move {
+                        let (files_to_index, mut files_remaining_rx) = index_task.await?;
+
+                        search_view.update(&mut cx, |search_view, cx| {
+                            search_view.semantic = Some(SemanticSearchState {
+                                file_count: files_to_index,
+                                outstanding_file_count: files_to_index,
+                                search_task: None,
+                                _progress_task: cx.spawn(|search_view, mut cx| async move {
+                                    while let Some(count) = files_remaining_rx.recv().await {
+                                        search_view
+                                            .update(&mut cx, |search_view, cx| {
+                                                if let Some(semantic_search_state) =
+                                                    &mut search_view.semantic
+                                                {
+                                                    semantic_search_state.outstanding_file_count =
+                                                        count;
+                                                    cx.notify();
+                                                    if count == 0 {
+                                                        return;
+                                                    }
+                                                }
+                                            })
+                                            .ok();
+                                    }
+                                }),
+                            });
+                        })?;
+                        anyhow::Ok(())
+                    })
+                    .detach_and_log_err(cx);
+                }
+            });
+            cx.notify();
+            true
+        } else {
+            false
+        }
+    }
+
     fn render_nav_button(
         &self,
         icon: &'static str,
@@ -953,6 +1062,42 @@ impl ProjectSearchBar {
         .into_any()
     }
 
+    fn render_semantic_search_button(&self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        let tooltip_style = theme::current(cx).tooltip.clone();
+        let is_active = if let Some(search) = self.active_project_search.as_ref() {
+            let search = search.read(cx);
+            search.semantic.is_some()
+        } else {
+            false
+        };
+
+        let region_id = 3;
+
+        MouseEventHandler::<Self, _>::new(region_id, cx, |state, cx| {
+            let theme = theme::current(cx);
+            let style = theme
+                .search
+                .option_button
+                .in_state(is_active)
+                .style_for(state);
+            Label::new("Semantic", style.text.clone())
+                .contained()
+                .with_style(style.container)
+        })
+        .on_click(MouseButton::Left, move |_, this, cx| {
+            this.toggle_semantic_search(cx);
+        })
+        .with_cursor_style(CursorStyle::PointingHand)
+        .with_tooltip::<Self>(
+            region_id,
+            format!("Toggle Semantic Search"),
+            Some(Box::new(ToggleSemanticSearch)),
+            tooltip_style,
+            cx,
+        )
+        .into_any()
+    }
+
     fn is_option_enabled(&self, option: SearchOption, cx: &AppContext) -> bool {
         if let Some(search) = self.active_project_search.as_ref() {
             let search = search.read(cx);
@@ -1049,6 +1194,7 @@ impl View for ProjectSearchBar {
                         )
                         .with_child(
                             Flex::row()
+                                .with_child(self.render_semantic_search_button(cx))
                                 .with_child(self.render_option_button(
                                     "Case",
                                     SearchOption::CaseSensitive,

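For reference, the progress wiring above boils down to this: index_project now returns the total file count plus a postage watch receiver of the outstanding count, and the view keeps a background task alive that folds each update into SemanticSearchState. A minimal sketch of that consumer loop, detached from gpui (the function name and the on_progress callback are illustrative, not part of this change):

use postage::{stream::Stream, watch};

/// Drain a watch channel of "files remaining" counts, reporting each update
/// and stopping once indexing has caught up.
async fn watch_indexing_progress(
    mut files_remaining_rx: watch::Receiver<usize>,
    mut on_progress: impl FnMut(usize),
) {
    while let Some(remaining) = files_remaining_rx.recv().await {
        on_progress(remaining);
        if remaining == 0 {
            break;
        }
    }
}

In the actual view this loop lives in _progress_task and calls cx.notify() on every update so the status label re-renders with "Indexing. N of M...".
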
crates/semantic_index/Cargo.toml

@@ -20,6 +20,7 @@ editor = { path = "../editor" }
 rpc = { path = "../rpc" }
 settings = { path = "../settings" }
 anyhow.workspace = true
+postage.workspace = true
 futures.workspace = true
 smol.workspace = true
 rusqlite = { version = "0.27.0", features = ["blob", "array", "modern_sqlite"] }

crates/semantic_index/src/db.rs

@@ -1,5 +1,5 @@
 use crate::{parsing::Document, SEMANTIC_INDEX_VERSION};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
 use project::Fs;
 use rpc::proto::Timestamp;
 use rusqlite::{
@@ -76,14 +76,14 @@ impl VectorDatabase {
         self.db
             .execute(
                 "
-                    DROP TABLE semantic_index_config;
-                    DROP TABLE worktrees;
-                    DROP TABLE files;
-                    DROP TABLE documents;
+                DROP TABLE IF EXISTS documents;
+                DROP TABLE IF EXISTS files;
+                DROP TABLE IF EXISTS worktrees;
+                DROP TABLE IF EXISTS semantic_index_config;
                 ",
                 [],
             )
-            .ok();
+            .context("failed to drop tables")?;
 
         // Initialize Vector Databasing Tables
         self.db.execute(

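The reset above now uses DROP TABLE IF EXISTS, reorders the statements, and reports failure through Context instead of swallowing it with .ok(), so reinitializing after a SEMANTIC_INDEX_VERSION bump works whether or not the tables already exist. The same pattern in isolation, sketched against a bare rusqlite Connection (execute_batch is used here to run the multi-statement string; the crate's own VectorDatabase wrapper differs):

use anyhow::{Context, Result};
use rusqlite::Connection;

/// Drop the index tables; IF EXISTS keeps the reset idempotent when some
/// tables were never created in the first place.
fn reset_tables(db: &Connection) -> Result<()> {
    db.execute_batch(
        "
        DROP TABLE IF EXISTS documents;
        DROP TABLE IF EXISTS files;
        DROP TABLE IF EXISTS worktrees;
        DROP TABLE IF EXISTS semantic_index_config;
        ",
    )
    .context("failed to drop tables")
}
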
crates/semantic_index/src/embedding.rs

@@ -86,6 +86,7 @@ impl OpenAIEmbeddings {
     async fn send_request(&self, api_key: &str, spans: Vec<&str>) -> Result<Response<AsyncBody>> {
         let request = Request::post("https://api.openai.com/v1/embeddings")
             .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(4))
             .header("Content-Type", "application/json")
             .header("Authorization", format!("Bearer {}", api_key))
             .body(
@@ -133,7 +134,11 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                     self.executor.timer(delay).await;
                 }
                 StatusCode::BAD_REQUEST => {
-                    log::info!("BAD REQUEST: {:?}", &response.status());
+                    log::info!(
+                        "BAD REQUEST: {:?} {:?}",
+                        &response.status(),
+                        response.body()
+                    );
                     // Don't worry about delaying bad request, as we can assume
                     // we haven't been rate limited yet.
                     for span in spans.iter_mut() {

crates/semantic_index/src/modal.rs

@@ -1,172 +0,0 @@
-use crate::{SearchResult, SemanticIndex};
-use editor::{scroll::autoscroll::Autoscroll, Editor};
-use gpui::{
-    actions, elements::*, AnyElement, AppContext, ModelHandle, MouseState, Task, ViewContext,
-    WeakViewHandle,
-};
-use picker::{Picker, PickerDelegate, PickerEvent};
-use project::{Project, ProjectPath};
-use std::{collections::HashMap, sync::Arc, time::Duration};
-use util::ResultExt;
-use workspace::Workspace;
-
-const MIN_QUERY_LEN: usize = 5;
-const EMBEDDING_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(500);
-
-actions!(semantic_search, [Toggle]);
-
-pub type SemanticSearch = Picker<SemanticSearchDelegate>;
-
-pub struct SemanticSearchDelegate {
-    workspace: WeakViewHandle<Workspace>,
-    project: ModelHandle<Project>,
-    semantic_index: ModelHandle<SemanticIndex>,
-    selected_match_index: usize,
-    matches: Vec<SearchResult>,
-    history: HashMap<String, Vec<SearchResult>>,
-}
-
-impl SemanticSearchDelegate {
-    // This is currently searching on every keystroke,
-    // This is wildly overkill, and has the potential to get expensive
-    // We will need to update this to throttle searching
-    pub fn new(
-        workspace: WeakViewHandle<Workspace>,
-        project: ModelHandle<Project>,
-        semantic_index: ModelHandle<SemanticIndex>,
-    ) -> Self {
-        Self {
-            workspace,
-            project,
-            semantic_index,
-            selected_match_index: 0,
-            matches: vec![],
-            history: HashMap::new(),
-        }
-    }
-}
-
-impl PickerDelegate for SemanticSearchDelegate {
-    fn placeholder_text(&self) -> Arc<str> {
-        "Search repository in natural language...".into()
-    }
-
-    fn confirm(&mut self, cx: &mut ViewContext<SemanticSearch>) {
-        if let Some(search_result) = self.matches.get(self.selected_match_index) {
-            // Open Buffer
-            let search_result = search_result.clone();
-            let buffer = self.project.update(cx, |project, cx| {
-                project.open_buffer(
-                    ProjectPath {
-                        worktree_id: search_result.worktree_id,
-                        path: search_result.file_path.clone().into(),
-                    },
-                    cx,
-                )
-            });
-
-            let workspace = self.workspace.clone();
-            let position = search_result.clone().byte_range.start;
-            cx.spawn(|_, mut cx| async move {
-                let buffer = buffer.await?;
-                workspace.update(&mut cx, |workspace, cx| {
-                    let editor = workspace.open_project_item::<Editor>(buffer, cx);
-                    editor.update(cx, |editor, cx| {
-                        editor.change_selections(Some(Autoscroll::center()), cx, |s| {
-                            s.select_ranges([position..position])
-                        });
-                    });
-                })?;
-                Ok::<_, anyhow::Error>(())
-            })
-            .detach_and_log_err(cx);
-            cx.emit(PickerEvent::Dismiss);
-        }
-    }
-
-    fn dismissed(&mut self, _cx: &mut ViewContext<SemanticSearch>) {}
-
-    fn match_count(&self) -> usize {
-        self.matches.len()
-    }
-
-    fn selected_index(&self) -> usize {
-        self.selected_match_index
-    }
-
-    fn set_selected_index(&mut self, ix: usize, _cx: &mut ViewContext<SemanticSearch>) {
-        self.selected_match_index = ix;
-    }
-
-    fn update_matches(&mut self, query: String, cx: &mut ViewContext<SemanticSearch>) -> Task<()> {
-        log::info!("Searching for {:?}...", query);
-        if query.len() < MIN_QUERY_LEN {
-            log::info!("Query below minimum length");
-            return Task::ready(());
-        }
-
-        let semantic_index = self.semantic_index.clone();
-        let project = self.project.clone();
-        cx.spawn(|this, mut cx| async move {
-            cx.background().timer(EMBEDDING_DEBOUNCE_INTERVAL).await;
-
-            let retrieved_cached = this.update(&mut cx, |this, _| {
-                let delegate = this.delegate_mut();
-                if delegate.history.contains_key(&query) {
-                    let historic_results = delegate.history.get(&query).unwrap().to_owned();
-                    delegate.matches = historic_results.clone();
-                    true
-                } else {
-                    false
-                }
-            });
-
-            if let Some(retrieved) = retrieved_cached.log_err() {
-                if !retrieved {
-                    let task = semantic_index.update(&mut cx, |store, cx| {
-                        store.search_project(project.clone(), query.to_string(), 10, cx)
-                    });
-
-                    if let Some(results) = task.await.log_err() {
-                        log::info!("Not queried previously, searching...");
-                        this.update(&mut cx, |this, _| {
-                            let delegate = this.delegate_mut();
-                            delegate.matches = results.clone();
-                            delegate.history.insert(query, results);
-                        })
-                        .ok();
-                    }
-                } else {
-                    log::info!("Already queried, retrieved directly from cached history");
-                }
-            }
-        })
-    }
-
-    fn render_match(
-        &self,
-        ix: usize,
-        mouse_state: &mut MouseState,
-        selected: bool,
-        cx: &AppContext,
-    ) -> AnyElement<Picker<Self>> {
-        let theme = theme::current(cx);
-        let style = &theme.picker.item;
-        let current_style = style.in_state(selected).style_for(mouse_state);
-
-        let search_result = &self.matches[ix];
-
-        let path = search_result.file_path.to_string_lossy();
-        let name = search_result.name.clone();
-
-        Flex::column()
-            .with_child(Text::new(name, current_style.label.text.clone()).with_soft_wrap(false))
-            .with_child(Label::new(
-                path.to_string(),
-                style.inactive_state().default.label.clone(),
-            ))
-            .contained()
-            .with_style(current_style.container)
-            .into_any()
-    }
-}

crates/semantic_index/src/semantic_index.rs

@@ -1,6 +1,5 @@
 mod db;
 mod embedding;
-mod modal;
 mod parsing;
 mod semantic_index_settings;
 
@@ -12,25 +11,20 @@ use anyhow::{anyhow, Result};
 use db::VectorDatabase;
 use embedding::{EmbeddingProvider, OpenAIEmbeddings};
 use futures::{channel::oneshot, Future};
-use gpui::{
-    AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, ViewContext,
-    WeakModelHandle,
-};
+use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
 use language::{Language, LanguageRegistry};
-use modal::{SemanticSearch, SemanticSearchDelegate, Toggle};
 use parking_lot::Mutex;
 use parsing::{CodeContextRetriever, Document, PARSEABLE_ENTIRE_FILE_TYPES};
+use postage::watch;
 use project::{Fs, Project, WorktreeId};
 use smol::channel;
 use std::{
-    collections::{HashMap, HashSet},
+    collections::HashMap,
+    mem,
     ops::Range,
     path::{Path, PathBuf},
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc, Weak,
-    },
-    time::{Instant, SystemTime},
+    sync::{Arc, Weak},
+    time::SystemTime,
 };
 use util::{
     channel::{ReleaseChannel, RELEASE_CHANNEL, RELEASE_CHANNEL_NAME},
@@ -38,9 +32,8 @@ use util::{
     paths::EMBEDDINGS_DIR,
     ResultExt,
 };
-use workspace::{Workspace, WorkspaceCreated};
 
-const SEMANTIC_INDEX_VERSION: usize = 1;
+const SEMANTIC_INDEX_VERSION: usize = 3;
 const EMBEDDINGS_BATCH_SIZE: usize = 150;
 
 pub fn init(
@@ -55,25 +48,6 @@ pub fn init(
         .join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
         .join("embeddings_db");
 
-    SemanticSearch::init(cx);
-    cx.add_action(
-        |workspace: &mut Workspace, _: &Toggle, cx: &mut ViewContext<Workspace>| {
-            if cx.has_global::<ModelHandle<SemanticIndex>>() {
-                let semantic_index = cx.global::<ModelHandle<SemanticIndex>>().clone();
-                workspace.toggle_modal(cx, |workspace, cx| {
-                    let project = workspace.project().clone();
-                    let workspace = cx.weak_handle();
-                    cx.add_view(|cx| {
-                        SemanticSearch::new(
-                            SemanticSearchDelegate::new(workspace, project, semantic_index),
-                            cx,
-                        )
-                    })
-                });
-            }
-        },
-    );
-
     if *RELEASE_CHANNEL == ReleaseChannel::Stable
         || !settings::get::<SemanticIndexSettings>(cx).enabled
     {
@@ -95,21 +69,6 @@ pub fn init(
 
         cx.update(|cx| {
             cx.set_global(semantic_index.clone());
-            cx.subscribe_global::<WorkspaceCreated, _>({
-                let semantic_index = semantic_index.clone();
-                move |event, cx| {
-                    let workspace = &event.0;
-                    if let Some(workspace) = workspace.upgrade(cx) {
-                        let project = workspace.read(cx).project().clone();
-                        if project.read(cx).is_local() {
-                            semantic_index.update(cx, |store, cx| {
-                                store.index_project(project, cx).detach();
-                            });
-                        }
-                    }
-                }
-            })
-            .detach();
         });
 
         anyhow::Ok(())
@@ -128,20 +87,17 @@ pub struct SemanticIndex {
     _embed_batch_task: Task<()>,
     _batch_files_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
-    next_job_id: Arc<AtomicUsize>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
     worktree_db_ids: Vec<(WorktreeId, i64)>,
-    outstanding_jobs: Arc<Mutex<HashSet<JobId>>>,
+    outstanding_job_count_rx: watch::Receiver<usize>,
+    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
 }
 
-type JobId = usize;
-
 struct JobHandle {
-    id: JobId,
-    set: Weak<Mutex<HashSet<JobId>>>,
+    tx: Weak<Mutex<watch::Sender<usize>>>,
 }
 
 impl ProjectState {
@@ -221,6 +177,14 @@ enum EmbeddingJob {
 }
 
 impl SemanticIndex {
+    pub fn global(cx: &AppContext) -> Option<ModelHandle<SemanticIndex>> {
+        if cx.has_global::<ModelHandle<Self>>() {
+            Some(cx.global::<ModelHandle<SemanticIndex>>().clone())
+        } else {
+            None
+        }
+    }
+
     async fn new(
         fs: Arc<dyn Fs>,
         database_url: PathBuf,
@@ -236,184 +200,69 @@ impl SemanticIndex {
             .await?;
 
         Ok(cx.add_model(|cx| {
-            // paths_tx -> embeddings_tx -> db_update_tx
-
-            //db_update_tx/rx: Updating Database
+            // Perform database operations
             let (db_update_tx, db_update_rx) = channel::unbounded();
-            let _db_update_task = cx.background().spawn(async move {
-                while let Ok(job) = db_update_rx.recv().await {
-                    match job {
-                        DbOperation::InsertFile {
-                            worktree_id,
-                            documents,
-                            path,
-                            mtime,
-                            job_handle,
-                        } => {
-                            db.insert_file(worktree_id, path, mtime, documents)
-                                .log_err();
-                            drop(job_handle)
-                        }
-                        DbOperation::Delete { worktree_id, path } => {
-                            db.delete_file(worktree_id, path).log_err();
-                        }
-                        DbOperation::FindOrCreateWorktree { path, sender } => {
-                            let id = db.find_or_create_worktree(&path);
-                            sender.send(id).ok();
-                        }
-                        DbOperation::FileMTimes {
-                            worktree_id: worktree_db_id,
-                            sender,
-                        } => {
-                            let file_mtimes = db.get_file_mtimes(worktree_db_id);
-                            sender.send(file_mtimes).ok();
-                        }
+            let _db_update_task = cx.background().spawn({
+                async move {
+                    while let Ok(job) = db_update_rx.recv().await {
+                        Self::run_db_operation(&db, job)
                     }
                 }
             });
 
-            // embed_tx/rx: Embed Batch and Send to Database
+            // Compute embeddings for each batch of documents and send them to the database.
             let (embed_batch_tx, embed_batch_rx) =
                 channel::unbounded::<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>();
             let _embed_batch_task = cx.background().spawn({
                 let db_update_tx = db_update_tx.clone();
                 let embedding_provider = embedding_provider.clone();
                 async move {
-                    while let Ok(mut embeddings_queue) = embed_batch_rx.recv().await {
-                        // Construct Batch
-                        let mut batch_documents = vec![];
-                        for (_, documents, _, _, _) in embeddings_queue.iter() {
-                            batch_documents
-                                .extend(documents.iter().map(|document| document.content.as_str()));
-                        }
-
-                        if let Ok(embeddings) =
-                            embedding_provider.embed_batch(batch_documents).await
-                        {
-                            log::trace!(
-                                "created {} embeddings for {} files",
-                                embeddings.len(),
-                                embeddings_queue.len(),
-                            );
-
-                            let mut i = 0;
-                            let mut j = 0;
-
-                            for embedding in embeddings.iter() {
-                                while embeddings_queue[i].1.len() == j {
-                                    i += 1;
-                                    j = 0;
-                                }
-
-                                embeddings_queue[i].1[j].embedding = embedding.to_owned();
-                                j += 1;
-                            }
-
-                            for (worktree_id, documents, path, mtime, job_handle) in
-                                embeddings_queue.into_iter()
-                            {
-                                for document in documents.iter() {
-                                    // TODO: Update this so it doesn't panic
-                                    assert!(
-                                        document.embedding.len() > 0,
-                                        "Document Embedding Not Complete"
-                                    );
-                                }
-
-                                db_update_tx
-                                    .send(DbOperation::InsertFile {
-                                        worktree_id,
-                                        documents,
-                                        path,
-                                        mtime,
-                                        job_handle,
-                                    })
-                                    .await
-                                    .unwrap();
-                            }
-                        }
+                    while let Ok(embeddings_queue) = embed_batch_rx.recv().await {
+                        Self::compute_embeddings_for_batch(
+                            embeddings_queue,
+                            &embedding_provider,
+                            &db_update_tx,
+                        )
+                        .await;
                     }
                 }
             });
 
-            // batch_tx/rx: Batch Files to Send for Embeddings
+            // Group documents into batches and send them to the embedding provider.
             let (batch_files_tx, batch_files_rx) = channel::unbounded::<EmbeddingJob>();
             let _batch_files_task = cx.background().spawn(async move {
                 let mut queue_len = 0;
                 let mut embeddings_queue = vec![];
-
                 while let Ok(job) = batch_files_rx.recv().await {
-                    let should_flush = match job {
-                        EmbeddingJob::Enqueue {
-                            documents,
-                            worktree_id,
-                            path,
-                            mtime,
-                            job_handle,
-                        } => {
-                            queue_len += &documents.len();
-                            embeddings_queue.push((
-                                worktree_id,
-                                documents,
-                                path,
-                                mtime,
-                                job_handle,
-                            ));
-                            queue_len >= EMBEDDINGS_BATCH_SIZE
-                        }
-                        EmbeddingJob::Flush => true,
-                    };
-
-                    if should_flush {
-                        embed_batch_tx.try_send(embeddings_queue).unwrap();
-                        embeddings_queue = vec![];
-                        queue_len = 0;
-                    }
+                    Self::enqueue_documents_to_embed(
+                        job,
+                        &mut queue_len,
+                        &mut embeddings_queue,
+                        &embed_batch_tx,
+                    );
                 }
             });
 
-            // parsing_files_tx/rx: Parsing Files to Embeddable Documents
+            // Parse files into embeddable documents.
             let (parsing_files_tx, parsing_files_rx) = channel::unbounded::<PendingFile>();
-
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
                 let fs = fs.clone();
                 let parsing_files_rx = parsing_files_rx.clone();
                 let batch_files_tx = batch_files_tx.clone();
+                let db_update_tx = db_update_tx.clone();
                 _parsing_files_tasks.push(cx.background().spawn(async move {
                     let mut retriever = CodeContextRetriever::new();
                     while let Ok(pending_file) = parsing_files_rx.recv().await {
-                        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err()
-                        {
-                            if let Some(documents) = retriever
-                                .parse_file(
-                                    &pending_file.relative_path,
-                                    &content,
-                                    pending_file.language,
-                                )
-                                .log_err()
-                            {
-                                log::trace!(
-                                    "parsed path {:?}: {} documents",
-                                    pending_file.relative_path,
-                                    documents.len()
-                                );
-
-                                batch_files_tx
-                                    .try_send(EmbeddingJob::Enqueue {
-                                        worktree_id: pending_file.worktree_db_id,
-                                        path: pending_file.relative_path,
-                                        mtime: pending_file.modified_time,
-                                        job_handle: pending_file.job_handle,
-                                        documents,
-                                    })
-                                    .unwrap();
-                            }
-                        }
-
-                        if parsing_files_rx.len() == 0 {
-                            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
-                        }
+                        Self::parse_file(
+                            &fs,
+                            pending_file,
+                            &mut retriever,
+                            &batch_files_tx,
+                            &parsing_files_rx,
+                            &db_update_tx,
+                        )
+                        .await;
                     }
                 }));
             }
@@ -424,7 +273,6 @@ impl SemanticIndex {
                 embedding_provider,
                 language_registry,
                 db_update_tx,
-                next_job_id: Default::default(),
                 parsing_files_tx,
                 _db_update_task,
                 _embed_batch_task,
@@ -435,6 +283,167 @@ impl SemanticIndex {
         }))
     }
 
+    fn run_db_operation(db: &VectorDatabase, job: DbOperation) {
+        match job {
+            DbOperation::InsertFile {
+                worktree_id,
+                documents,
+                path,
+                mtime,
+                job_handle,
+            } => {
+                db.insert_file(worktree_id, path, mtime, documents)
+                    .log_err();
+                drop(job_handle)
+            }
+            DbOperation::Delete { worktree_id, path } => {
+                db.delete_file(worktree_id, path).log_err();
+            }
+            DbOperation::FindOrCreateWorktree { path, sender } => {
+                let id = db.find_or_create_worktree(&path);
+                sender.send(id).ok();
+            }
+            DbOperation::FileMTimes {
+                worktree_id: worktree_db_id,
+                sender,
+            } => {
+                let file_mtimes = db.get_file_mtimes(worktree_db_id);
+                sender.send(file_mtimes).ok();
+            }
+        }
+    }
+
+    async fn compute_embeddings_for_batch(
+        mut embeddings_queue: Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+        embedding_provider: &Arc<dyn EmbeddingProvider>,
+        db_update_tx: &channel::Sender<DbOperation>,
+    ) {
+        let mut batch_documents = vec![];
+        for (_, documents, _, _, _) in embeddings_queue.iter() {
+            batch_documents.extend(documents.iter().map(|document| document.content.as_str()));
+        }
+
+        if let Ok(embeddings) = embedding_provider.embed_batch(batch_documents).await {
+            log::trace!(
+                "created {} embeddings for {} files",
+                embeddings.len(),
+                embeddings_queue.len(),
+            );
+
+            let mut i = 0;
+            let mut j = 0;
+
+            for embedding in embeddings.iter() {
+                while embeddings_queue[i].1.len() == j {
+                    i += 1;
+                    j = 0;
+                }
+
+                embeddings_queue[i].1[j].embedding = embedding.to_owned();
+                j += 1;
+            }
+
+            for (worktree_id, documents, path, mtime, job_handle) in embeddings_queue.into_iter() {
+                // for document in documents.iter() {
+                //     // TODO: Update this so it doesn't panic
+                //     assert!(
+                //         document.embedding.len() > 0,
+                //         "Document Embedding Not Complete"
+                //     );
+                // }
+
+                db_update_tx
+                    .send(DbOperation::InsertFile {
+                        worktree_id,
+                        documents,
+                        path,
+                        mtime,
+                        job_handle,
+                    })
+                    .await
+                    .unwrap();
+            }
+        }
+    }
+
+    fn enqueue_documents_to_embed(
+        job: EmbeddingJob,
+        queue_len: &mut usize,
+        embeddings_queue: &mut Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>,
+        embed_batch_tx: &channel::Sender<Vec<(i64, Vec<Document>, PathBuf, SystemTime, JobHandle)>>,
+    ) {
+        let should_flush = match job {
+            EmbeddingJob::Enqueue {
+                documents,
+                worktree_id,
+                path,
+                mtime,
+                job_handle,
+            } => {
+                *queue_len += &documents.len();
+                embeddings_queue.push((worktree_id, documents, path, mtime, job_handle));
+                *queue_len >= EMBEDDINGS_BATCH_SIZE
+            }
+            EmbeddingJob::Flush => true,
+        };
+
+        if should_flush {
+            embed_batch_tx
+                .try_send(mem::take(embeddings_queue))
+                .unwrap();
+            *queue_len = 0;
+        }
+    }
+
+    async fn parse_file(
+        fs: &Arc<dyn Fs>,
+        pending_file: PendingFile,
+        retriever: &mut CodeContextRetriever,
+        batch_files_tx: &channel::Sender<EmbeddingJob>,
+        parsing_files_rx: &channel::Receiver<PendingFile>,
+        db_update_tx: &channel::Sender<DbOperation>,
+    ) {
+        if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
+            if let Some(documents) = retriever
+                .parse_file(&pending_file.relative_path, &content, pending_file.language)
+                .log_err()
+            {
+                log::trace!(
+                    "parsed path {:?}: {} documents",
+                    pending_file.relative_path,
+                    documents.len()
+                );
+
+                if documents.len() == 0 {
+                    db_update_tx
+                        .send(DbOperation::InsertFile {
+                            worktree_id: pending_file.worktree_db_id,
+                            documents,
+                            path: pending_file.relative_path,
+                            mtime: pending_file.modified_time,
+                            job_handle: pending_file.job_handle,
+                        })
+                        .await
+                        .unwrap();
+                } else {
+                    batch_files_tx
+                        .try_send(EmbeddingJob::Enqueue {
+                            worktree_id: pending_file.worktree_db_id,
+                            path: pending_file.relative_path,
+                            mtime: pending_file.modified_time,
+                            job_handle: pending_file.job_handle,
+                            documents,
+                        })
+                        .unwrap();
+                }
+            }
+        }
+
+        if parsing_files_rx.len() == 0 {
+            batch_files_tx.try_send(EmbeddingJob::Flush).unwrap();
+        }
+    }
+
     fn find_or_create_worktree(&self, path: PathBuf) -> impl Future<Output = Result<i64>> {
         let (tx, rx) = oneshot::channel();
         self.db_update_tx
@@ -457,11 +466,11 @@ impl SemanticIndex {
         async move { rx.await? }
     }
 
-    fn index_project(
+    pub fn index_project(
         &mut self,
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<usize>> {
+    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
         let worktree_scans_complete = project
             .read(cx)
             .worktrees(cx)
@@ -483,7 +492,6 @@ impl SemanticIndex {
         let language_registry = self.language_registry.clone();
         let db_update_tx = self.db_update_tx.clone();
         let parsing_files_tx = self.parsing_files_tx.clone();
-        let next_job_id = self.next_job_id.clone();
 
         cx.spawn(|this, mut cx| async move {
             futures::future::join_all(worktree_scans_complete).await;
@@ -509,8 +517,8 @@ impl SemanticIndex {
                 );
             }
 
-            // let mut pending_files: Vec<(PathBuf, ((i64, PathBuf, Arc<Language>, SystemTime), SystemTime))> = vec![];
-            let outstanding_jobs = Arc::new(Mutex::new(HashSet::new()));
+            let (job_count_tx, job_count_rx) = watch::channel_with(0);
+            let job_count_tx = Arc::new(Mutex::new(job_count_tx));
             this.update(&mut cx, |this, _| {
                 this.projects.insert(
                     project.downgrade(),
@@ -519,7 +527,8 @@ impl SemanticIndex {
                             .iter()
                             .map(|(a, b)| (*a, *b))
                             .collect(),
-                        outstanding_jobs: outstanding_jobs.clone(),
+                        outstanding_job_count_rx: job_count_rx.clone(),
+                        outstanding_job_count_tx: job_count_tx.clone(),
                     },
                 );
             });
@@ -527,7 +536,6 @@ impl SemanticIndex {
             cx.background()
                 .spawn(async move {
                     let mut count = 0;
-                    let t0 = Instant::now();
                     for worktree in worktrees.into_iter() {
                         let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
                         for file in worktree.files(false, 0) {
@@ -552,14 +560,11 @@ impl SemanticIndex {
                                     .map_or(false, |existing_mtime| existing_mtime == file.mtime);
 
                                 if !already_stored {
-                                    log::trace!("sending for parsing: {:?}", path_buf);
                                     count += 1;
-                                    let job_id = next_job_id.fetch_add(1, atomic::Ordering::SeqCst);
+                                    *job_count_tx.lock().borrow_mut() += 1;
                                     let job_handle = JobHandle {
-                                        id: job_id,
-                                        set: Arc::downgrade(&outstanding_jobs),
+                                        tx: Arc::downgrade(&job_count_tx),
                                     };
-                                    outstanding_jobs.lock().insert(job_id);
                                     parsing_files_tx
                                         .try_send(PendingFile {
                                             worktree_db_id: db_ids_by_worktree_id[&worktree.id()],
@@ -582,27 +587,22 @@ impl SemanticIndex {
                                 .unwrap();
                         }
                     }
-                    log::trace!(
-                        "parsing worktree completed in {:?}",
-                        t0.elapsed().as_millis()
-                    );
 
-                    Ok(count)
+                    anyhow::Ok((count, job_count_rx))
                 })
                 .await
         })
     }
 
-    pub fn remaining_files_to_index_for_project(
+    pub fn outstanding_job_count_rx(
         &self,
         project: &ModelHandle<Project>,
-    ) -> Option<usize> {
+    ) -> Option<watch::Receiver<usize>> {
         Some(
             self.projects
                 .get(&project.downgrade())?
-                .outstanding_jobs
-                .lock()
-                .len(),
+                .outstanding_job_count_rx
+                .clone(),
         )
     }
 
@@ -678,8 +678,9 @@ impl Entity for SemanticIndex {
 
 impl Drop for JobHandle {
     fn drop(&mut self) {
-        if let Some(set) = self.set.upgrade() {
-            set.lock().remove(&self.id);
+        if let Some(tx) = self.tx.upgrade() {
+            let mut tx = tx.lock();
+            *tx.borrow_mut() -= 1;
         }
     }
 }

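The bookkeeping change above swaps a HashSet of job ids for a shared watch counter: index_project bumps the sender for every file it queues, and each JobHandle decrements it on Drop once the file has been written to the database, so any receiver can observe how much work is left. The pattern on its own, as a minimal sketch independent of the indexer's channels (the names CountedJob and start_job are illustrative):

use parking_lot::Mutex;
use postage::watch;
use std::sync::{Arc, Weak};

/// Mirrors JobHandle: decrements the shared outstanding-job counter when
/// dropped, notifying every watch::Receiver that is tracking progress.
struct CountedJob {
    tx: Weak<Mutex<watch::Sender<usize>>>,
}

impl Drop for CountedJob {
    fn drop(&mut self) {
        if let Some(tx) = self.tx.upgrade() {
            *tx.lock().borrow_mut() -= 1;
        }
    }
}

/// Register one unit of outstanding work against a counter created with
/// watch::channel_with(0).
fn start_job(counter: &Arc<Mutex<watch::Sender<usize>>>) -> CountedJob {
    *counter.lock().borrow_mut() += 1;
    CountedJob {
        tx: Arc::downgrade(counter),
    }
}

Holding only a Weak reference to the sender means that if the project's state is dropped mid-index, the handles still in flight simply stop updating the counter instead of keeping it alive.
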
crates/semantic_index/src/semantic_index_tests.rs

@@ -88,18 +88,13 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     let worktree_id = project.read_with(cx, |project, cx| {
         project.worktrees(cx).next().unwrap().read(cx).id()
     });
-    let file_count = store
+    let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 3);
     cx.foreground().run_until_parked();
-    store.update(cx, |store, _cx| {
-        assert_eq!(
-            store.remaining_files_to_index_for_project(&project),
-            Some(0)
-        );
-    });
+    assert_eq!(*outstanding_file_count.borrow(), 0);
 
     let search_results = store
         .update(cx, |store, cx| {
@@ -128,19 +123,14 @@ async fn test_semantic_index(cx: &mut TestAppContext) {
     cx.foreground().run_until_parked();
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    let file_count = store
+    let (file_count, outstanding_file_count) = store
         .update(cx, |store, cx| store.index_project(project.clone(), cx))
         .await
         .unwrap();
     assert_eq!(file_count, 1);
 
     cx.foreground().run_until_parked();
-    store.update(cx, |store, _cx| {
-        assert_eq!(
-            store.remaining_files_to_index_for_project(&project),
-            Some(0)
-        );
-    });
+    assert_eq!(*outstanding_file_count.borrow(), 0);
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,