Rework how we track projects and worktrees in semantic index (#2938)

Created by Antonio Scandurra

This pull request reworks how we track projects and worktrees in the
semantic index, improving the overall semantic search experience. We're
still missing support for collaboration and for searching modified
buffers, which we'll tackle after we take a detour into reducing the
number of tokens used to generate embeddings.

Release Notes:

- Fixed a bug that could prevent semantic search from working when it was
deployed right after opening a project.
- Fixed a panic that could sometimes occur when using semantic search
while simultaneously changing a file.
- Fixed a bug that prevented semantic search from including new worktrees
after they were added to a project.
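
At a glance, the indexing API changes shape: `index_project` no longer returns the file count and a progress receiver; callers now detach the returned indexing task and observe progress through a separate `pending_file_count` watch channel. A sketch of the new call pattern, condensed from the project_search.rs diff below:

    let mut pending_file_count_rx = semantic_index.update(cx, |semantic_index, cx| {
        // Kick off indexing in the background; errors are logged rather than returned.
        semantic_index
            .index_project(project.clone(), cx)
            .detach_and_log_err(cx);
        // Subscribe to the number of files still waiting to be indexed.
        semantic_index.pending_file_count(&project).unwrap()
    });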

Change summary

Cargo.lock                                        |   1 
crates/search/src/project_search.rs               |  42 
crates/semantic_index/Cargo.toml                  |   2 
crates/semantic_index/src/db.rs                   |  69 
crates/semantic_index/src/embedding_queue.rs      | 134 +-
crates/semantic_index/src/parsing.rs              |  66 
crates/semantic_index/src/semantic_index.rs       | 851 +++++++++-------
crates/semantic_index/src/semantic_index_tests.rs |  68 
8 files changed, 635 insertions(+), 598 deletions(-)

Detailed changes

Cargo.lock

@@ -6722,6 +6722,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "bincode",
+ "collections",
  "ctor",
  "editor",
  "env_logger 0.9.3",

crates/search/src/project_search.rs

@@ -12,15 +12,13 @@ use editor::{
     SelectAll, MAX_TAB_TITLE_LEN,
 };
 use futures::StreamExt;
-
-use gpui::platform::PromptLevel;
-
 use gpui::{
-    actions, elements::*, platform::MouseButton, Action, AnyElement, AnyViewHandle, AppContext,
-    Entity, ModelContext, ModelHandle, Subscription, Task, View, ViewContext, ViewHandle,
-    WeakModelHandle, WeakViewHandle,
+    actions,
+    elements::*,
+    platform::{MouseButton, PromptLevel},
+    Action, AnyElement, AnyViewHandle, AppContext, Entity, ModelContext, ModelHandle, Subscription,
+    Task, View, ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
 };
-
 use menu::Confirm;
 use postage::stream::Stream;
 use project::{
@@ -132,8 +130,7 @@ pub struct ProjectSearchView {
 }
 
 struct SemanticSearchState {
-    file_count: usize,
-    outstanding_file_count: usize,
+    pending_file_count: usize,
     _progress_task: Task<()>,
 }
 
@@ -319,12 +316,8 @@ impl View for ProjectSearchView {
             };
 
             let semantic_status = if let Some(semantic) = &self.semantic_state {
-                if semantic.outstanding_file_count > 0 {
-                    format!(
-                        "Indexing: {} of {}...",
-                        semantic.file_count - semantic.outstanding_file_count,
-                        semantic.file_count
-                    )
+                if semantic.pending_file_count > 0 {
+                    format!("Remaining files to index: {}", semantic.pending_file_count)
                 } else {
                     "Indexing complete".to_string()
                 }
@@ -641,26 +634,27 @@ impl ProjectSearchView {
 
             let project = self.model.read(cx).project.clone();
 
-            let index_task = semantic_index.update(cx, |semantic_index, cx| {
-                semantic_index.index_project(project, cx)
+            let mut pending_file_count_rx = semantic_index.update(cx, |semantic_index, cx| {
+                semantic_index
+                    .index_project(project.clone(), cx)
+                    .detach_and_log_err(cx);
+                semantic_index.pending_file_count(&project).unwrap()
             });
 
             cx.spawn(|search_view, mut cx| async move {
-                let (files_to_index, mut files_remaining_rx) = index_task.await?;
-
                 search_view.update(&mut cx, |search_view, cx| {
                     cx.notify();
+                    let pending_file_count = *pending_file_count_rx.borrow();
                     search_view.semantic_state = Some(SemanticSearchState {
-                        file_count: files_to_index,
-                        outstanding_file_count: files_to_index,
+                        pending_file_count,
                         _progress_task: cx.spawn(|search_view, mut cx| async move {
-                            while let Some(count) = files_remaining_rx.recv().await {
+                            while let Some(count) = pending_file_count_rx.recv().await {
                                 search_view
                                     .update(&mut cx, |search_view, cx| {
                                         if let Some(semantic_search_state) =
                                             &mut search_view.semantic_state
                                         {
-                                            semantic_search_state.outstanding_file_count = count;
+                                            semantic_search_state.pending_file_count = count;
                                             cx.notify();
                                             if count == 0 {
                                                 return;
@@ -959,7 +953,7 @@ impl ProjectSearchView {
         match mode {
             SearchMode::Semantic => {
                 if let Some(semantic) = &mut self.semantic_state {
-                    if semantic.outstanding_file_count > 0 {
+                    if semantic.pending_file_count > 0 {
                         return;
                     }
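
The progress loop above leans on `postage::watch` semantics: a receiver can be read synchronously with `borrow()` and then awaited for subsequent updates via `recv()`. A minimal sketch of that pattern, assuming only the `postage` crate (the function name is illustrative):

    use postage::{stream::Stream, watch};

    async fn report_progress(mut pending_file_count_rx: watch::Receiver<usize>) {
        // Read the current value synchronously before waiting for changes.
        println!("Remaining files to index: {}", *pending_file_count_rx.borrow());

        // Each send on the paired sender wakes this loop with the new count.
        while let Some(count) = pending_file_count_rx.recv().await {
            println!("Remaining files to index: {}", count);
            if count == 0 {
                break;
            }
        }
    }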
 

crates/semantic_index/Cargo.toml

@@ -9,6 +9,7 @@ path = "src/semantic_index.rs"
 doctest = false
 
 [dependencies]
+collections = { path = "../collections" }
 gpui = { path = "../gpui" }
 language = { path = "../language" }
 project = { path = "../project" }
@@ -42,6 +43,7 @@ sha1 = "0.10.5"
 parse_duration = "2.1.1"
 
 [dev-dependencies]
+collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
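
The new `collections` dependency likely explains the `HashMap::new()` → `HashMap::default()` changes in the hunks below: Zed's `collections` crate re-exports map types over a custom hasher (presumably deterministic under `test-support`), and a `HashMap` with a non-default `BuildHasher` must be constructed with `default()`, since `new()` is only defined for the std hasher. A sketch of that constraint using only the standard library (the alias is illustrative; the real crate's hasher may differ):

    use std::collections::{hash_map::DefaultHasher, HashMap};
    use std::hash::BuildHasherDefault;

    // Illustrative alias standing in for `collections::HashMap`.
    type CustomMap<K, V> = HashMap<K, V, BuildHasherDefault<DefaultHasher>>;

    fn main() {
        // `CustomMap::new()` does not compile for a custom hasher; `default()` does.
        let mut map: CustomMap<&str, i32> = CustomMap::default();
        map.insert("spans", 1);
    }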

crates/semantic_index/src/db.rs

@@ -1,9 +1,10 @@
 use crate::{
     embedding::Embedding,
-    parsing::{Document, DocumentDigest},
+    parsing::{Span, SpanDigest},
     SEMANTIC_INDEX_VERSION,
 };
 use anyhow::{anyhow, Context, Result};
+use collections::HashMap;
 use futures::channel::oneshot;
 use gpui::executor;
 use project::{search::PathMatcher, Fs};
@@ -12,7 +13,6 @@ use rusqlite::params;
 use rusqlite::types::Value;
 use std::{
     cmp::Ordering,
-    collections::HashMap,
     future::Future,
     ops::Range,
     path::{Path, PathBuf},
@@ -124,8 +124,12 @@ impl VectorDatabase {
             }
 
             log::trace!("vector database schema out of date. updating...");
+            // We renamed the `documents` table to `spans`, so we want to drop
+            // `documents` without recreating it if it exists.
             db.execute("DROP TABLE IF EXISTS documents", [])
                 .context("failed to drop 'documents' table")?;
+            db.execute("DROP TABLE IF EXISTS spans", [])
+                .context("failed to drop 'spans' table")?;
             db.execute("DROP TABLE IF EXISTS files", [])
                 .context("failed to drop 'files' table")?;
             db.execute("DROP TABLE IF EXISTS worktrees", [])
@@ -174,7 +178,7 @@ impl VectorDatabase {
             )?;
 
             db.execute(
-                "CREATE TABLE documents (
+                "CREATE TABLE spans (
                     id INTEGER PRIMARY KEY AUTOINCREMENT,
                     file_id INTEGER NOT NULL,
                     start_byte INTEGER NOT NULL,
@@ -195,7 +199,7 @@ impl VectorDatabase {
     pub fn delete_file(
         &self,
         worktree_id: i64,
-        delete_path: PathBuf,
+        delete_path: Arc<Path>,
     ) -> impl Future<Output = Result<()>> {
         self.transact(move |db| {
             db.execute(
@@ -209,9 +213,9 @@ impl VectorDatabase {
     pub fn insert_file(
         &self,
         worktree_id: i64,
-        path: PathBuf,
+        path: Arc<Path>,
         mtime: SystemTime,
-        documents: Vec<Document>,
+        spans: Vec<Span>,
     ) -> impl Future<Output = Result<()>> {
         self.transact(move |db| {
             // Return the existing ID, if both the file and mtime match
@@ -231,7 +235,7 @@ impl VectorDatabase {
             let t0 = Instant::now();
             let mut query = db.prepare(
                 "
-                INSERT INTO documents
+                INSERT INTO spans
                 (file_id, start_byte, end_byte, name, embedding, digest)
                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
                 ",
@@ -241,14 +245,14 @@ impl VectorDatabase {
                 t0.elapsed().as_millis()
             );
 
-            for document in documents {
+            for span in spans {
                 query.execute(params![
                     file_id,
-                    document.range.start.to_string(),
-                    document.range.end.to_string(),
-                    document.name,
-                    document.embedding,
-                    document.digest
+                    span.range.start.to_string(),
+                    span.range.end.to_string(),
+                    span.name,
+                    span.embedding,
+                    span.digest
                 ])?;
             }
 
@@ -278,17 +282,17 @@ impl VectorDatabase {
     pub fn embeddings_for_files(
         &self,
         worktree_id_file_paths: HashMap<i64, Vec<Arc<Path>>>,
-    ) -> impl Future<Output = Result<HashMap<DocumentDigest, Embedding>>> {
+    ) -> impl Future<Output = Result<HashMap<SpanDigest, Embedding>>> {
         self.transact(move |db| {
             let mut query = db.prepare(
                 "
                 SELECT digest, embedding
-                FROM documents
-                LEFT JOIN files ON files.id = documents.file_id
+                FROM spans
+                LEFT JOIN files ON files.id = spans.file_id
                 WHERE files.worktree_id = ? AND files.relative_path IN rarray(?)
             ",
             )?;
-            let mut embeddings_by_digest = HashMap::new();
+            let mut embeddings_by_digest = HashMap::default();
             for (worktree_id, file_paths) in worktree_id_file_paths {
                 let file_paths = Rc::new(
                     file_paths
@@ -297,10 +301,7 @@ impl VectorDatabase {
                         .collect::<Vec<_>>(),
                 );
                 let rows = query.query_map(params![worktree_id, file_paths], |row| {
-                    Ok((
-                        row.get::<_, DocumentDigest>(0)?,
-                        row.get::<_, Embedding>(1)?,
-                    ))
+                    Ok((row.get::<_, SpanDigest>(0)?, row.get::<_, Embedding>(1)?))
                 })?;
 
                 for row in rows {
@@ -316,7 +317,7 @@ impl VectorDatabase {
 
     pub fn find_or_create_worktree(
         &self,
-        worktree_root_path: PathBuf,
+        worktree_root_path: Arc<Path>,
     ) -> impl Future<Output = Result<i64>> {
         self.transact(move |db| {
             let mut worktree_query =
@@ -351,7 +352,7 @@ impl VectorDatabase {
                 WHERE worktree_id = ?1
                 ORDER BY relative_path",
             )?;
-            let mut result: HashMap<PathBuf, SystemTime> = HashMap::new();
+            let mut result: HashMap<PathBuf, SystemTime> = HashMap::default();
             for row in statement.query_map(params![worktree_id], |row| {
                 Ok((
                     row.get::<_, String>(0)?.into(),
@@ -379,7 +380,7 @@ impl VectorDatabase {
         let file_ids = file_ids.to_vec();
         self.transact(move |db| {
             let mut results = Vec::<(i64, f32)>::with_capacity(limit + 1);
-            Self::for_each_document(db, &file_ids, |id, embedding| {
+            Self::for_each_span(db, &file_ids, |id, embedding| {
                 let similarity = embedding.similarity(&query_embedding);
                 let ix = match results.binary_search_by(|(_, s)| {
                     similarity.partial_cmp(&s).unwrap_or(Ordering::Equal)
@@ -434,7 +435,7 @@ impl VectorDatabase {
         })
     }
 
-    fn for_each_document(
+    fn for_each_span(
         db: &rusqlite::Connection,
         file_ids: &[i64],
         mut f: impl FnMut(i64, Embedding),
@@ -444,7 +445,7 @@ impl VectorDatabase {
             SELECT
                 id, embedding
             FROM
-                documents
+                spans
             WHERE
                 file_id IN rarray(?)
             ",
@@ -459,7 +460,7 @@ impl VectorDatabase {
         Ok(())
     }
 
-    pub fn get_documents_by_ids(
+    pub fn spans_for_ids(
         &self,
         ids: &[i64],
     ) -> impl Future<Output = Result<Vec<(i64, PathBuf, Range<usize>)>>> {
@@ -468,16 +469,16 @@ impl VectorDatabase {
             let mut statement = db.prepare(
                 "
                     SELECT
-                        documents.id,
+                        spans.id,
                         files.worktree_id,
                         files.relative_path,
-                        documents.start_byte,
-                        documents.end_byte
+                        spans.start_byte,
+                        spans.end_byte
                     FROM
-                        documents, files
+                        spans, files
                     WHERE
-                        documents.file_id = files.id AND
-                        documents.id in rarray(?)
+                        spans.file_id = files.id AND
+                        spans.id in rarray(?)
                 ",
             )?;
 
@@ -500,7 +501,7 @@ impl VectorDatabase {
             for id in &ids {
                 let value = values_by_id
                     .remove(id)
-                    .ok_or(anyhow!("missing document id {}", id))?;
+                    .ok_or(anyhow!("missing span id {}", id))?;
                 results.push(value);
             }
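
The queries above use rusqlite's `rarray()` table-valued function to bind a whole list of ids or paths in one statement. A hedged sketch of that binding pattern, assuming rusqlite is built with the "array" feature; the table and column names follow the `spans` schema above:

    use rusqlite::{params, types::Value, Connection, Result};
    use std::rc::Rc;

    fn span_ids_for_files(db: &Connection, file_ids: &[i64]) -> Result<Vec<i64>> {
        // Register the rarray() virtual table on this connection.
        rusqlite::vtab::array::load_module(db)?;
        // rarray() consumes an Rc<Vec<Value>> bound as a single parameter.
        let ids = Rc::new(file_ids.iter().map(|&id| Value::from(id)).collect::<Vec<Value>>());
        let mut query = db.prepare("SELECT id FROM spans WHERE file_id IN rarray(?)")?;
        let rows = query.query_map(params![ids], |row| row.get::<_, i64>(0))?;
        rows.collect()
    }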
 

crates/semantic_index/src/embedding_queue.rs

@@ -1,15 +1,15 @@
-use crate::{embedding::EmbeddingProvider, parsing::Document, JobHandle};
+use crate::{embedding::EmbeddingProvider, parsing::Span, JobHandle};
 use gpui::executor::Background;
 use parking_lot::Mutex;
 use smol::channel;
-use std::{mem, ops::Range, path::PathBuf, sync::Arc, time::SystemTime};
+use std::{mem, ops::Range, path::Path, sync::Arc, time::SystemTime};
 
 #[derive(Clone)]
 pub struct FileToEmbed {
     pub worktree_id: i64,
-    pub path: PathBuf,
+    pub path: Arc<Path>,
     pub mtime: SystemTime,
-    pub documents: Vec<Document>,
+    pub spans: Vec<Span>,
     pub job_handle: JobHandle,
 }
 
@@ -19,7 +19,7 @@ impl std::fmt::Debug for FileToEmbed {
             .field("worktree_id", &self.worktree_id)
             .field("path", &self.path)
             .field("mtime", &self.mtime)
-            .field("document", &self.documents)
+            .field("spans", &self.spans)
             .finish_non_exhaustive()
     }
 }
@@ -29,13 +29,13 @@ impl PartialEq for FileToEmbed {
         self.worktree_id == other.worktree_id
             && self.path == other.path
             && self.mtime == other.mtime
-            && self.documents == other.documents
+            && self.spans == other.spans
     }
 }
 
 pub struct EmbeddingQueue {
     embedding_provider: Arc<dyn EmbeddingProvider>,
-    pending_batch: Vec<FileToEmbedFragment>,
+    pending_batch: Vec<FileFragmentToEmbed>,
     executor: Arc<Background>,
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
@@ -43,9 +43,9 @@ pub struct EmbeddingQueue {
 }
 
 #[derive(Clone)]
-pub struct FileToEmbedFragment {
+pub struct FileFragmentToEmbed {
     file: Arc<Mutex<FileToEmbed>>,
-    document_range: Range<usize>,
+    span_range: Range<usize>,
 }
 
 impl EmbeddingQueue {
@@ -62,43 +62,40 @@ impl EmbeddingQueue {
     }
 
     pub fn push(&mut self, file: FileToEmbed) {
-        if file.documents.is_empty() {
+        if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
             return;
         }
 
         let file = Arc::new(Mutex::new(file));
 
-        self.pending_batch.push(FileToEmbedFragment {
+        self.pending_batch.push(FileFragmentToEmbed {
             file: file.clone(),
-            document_range: 0..0,
+            span_range: 0..0,
         });
 
-        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
-        let mut saved_tokens = 0;
-        for (ix, document) in file.lock().documents.iter().enumerate() {
-            let document_token_count = if document.embedding.is_none() {
-                document.token_count
+        let mut fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
+        for (ix, span) in file.lock().spans.iter().enumerate() {
+            let span_token_count = if span.embedding.is_none() {
+                span.token_count
             } else {
-                saved_tokens += document.token_count;
                 0
             };
 
-            let next_token_count = self.pending_batch_token_count + document_token_count;
+            let next_token_count = self.pending_batch_token_count + span_token_count;
             if next_token_count > self.embedding_provider.max_tokens_per_batch() {
                 let range_end = fragment_range.end;
                 self.flush();
-                self.pending_batch.push(FileToEmbedFragment {
+                self.pending_batch.push(FileFragmentToEmbed {
                     file: file.clone(),
-                    document_range: range_end..range_end,
+                    span_range: range_end..range_end,
                 });
-                fragment_range = &mut self.pending_batch.last_mut().unwrap().document_range;
+                fragment_range = &mut self.pending_batch.last_mut().unwrap().span_range;
             }
 
             fragment_range.end = ix + 1;
-            self.pending_batch_token_count += document_token_count;
+            self.pending_batch_token_count += span_token_count;
         }
-        log::trace!("Saved Tokens: {:?}", saved_tokens);
     }
 
     pub fn flush(&mut self) {
@@ -111,60 +108,55 @@ impl EmbeddingQueue {
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
 
-        self.executor.spawn(async move {
-            let mut spans = Vec::new();
-            let mut document_count = 0;
-            for fragment in &batch {
-                let file = fragment.file.lock();
-                document_count += file.documents[fragment.document_range.clone()].len();
-                spans.extend(
-                    {
-                        file.documents[fragment.document_range.clone()]
-                            .iter().filter(|d| d.embedding.is_none())
-                            .map(|d| d.content.clone())
-                        }
-                );
-            }
-
-            log::trace!("Documents Length: {:?}", document_count);
-            log::trace!("Span Length: {:?}", spans.clone().len());
-
-            // If spans is 0, just send the fragment to the finished files if it's the last one.
-            if spans.len() == 0 {
-                for fragment in batch.clone() {
-                    if let Some(file) = Arc::into_inner(fragment.file) {
-                        finished_files_tx.try_send(file.into_inner()).unwrap();
-                    }
+        self.executor
+            .spawn(async move {
+                let mut spans = Vec::new();
+                for fragment in &batch {
+                    let file = fragment.file.lock();
+                    spans.extend(
+                        file.spans[fragment.span_range.clone()]
+                            .iter()
+                            .filter(|d| d.embedding.is_none())
+                            .map(|d| d.content.clone()),
+                    );
                 }
-                return;
-            };
-
-            match embedding_provider.embed_batch(spans).await {
-                Ok(embeddings) => {
-                    let mut embeddings = embeddings.into_iter();
-                    for fragment in batch {
-                        for document in
-                            &mut fragment.file.lock().documents[fragment.document_range.clone()].iter_mut().filter(|d| d.embedding.is_none())
-                        {
-                            if let Some(embedding) = embeddings.next() {
-                                document.embedding = Some(embedding);
-                            } else {
-                                //
-                                log::error!("number of embeddings returned different from number of documents");
-                            }
-                        }
 
+                // If spans is 0, just send the fragment to the finished files if it's the last one.
+                if spans.is_empty() {
+                    for fragment in batch.clone() {
                         if let Some(file) = Arc::into_inner(fragment.file) {
                             finished_files_tx.try_send(file.into_inner()).unwrap();
                         }
                     }
+                    return;
+                };
+
+                match embedding_provider.embed_batch(spans).await {
+                    Ok(embeddings) => {
+                        let mut embeddings = embeddings.into_iter();
+                        for fragment in batch {
+                            for span in &mut fragment.file.lock().spans[fragment.span_range.clone()]
+                                .iter_mut()
+                                .filter(|d| d.embedding.is_none())
+                            {
+                                if let Some(embedding) = embeddings.next() {
+                                    span.embedding = Some(embedding);
+                                } else {
+                                    log::error!("number of embeddings != number of documents");
+                                }
+                            }
+
+                            if let Some(file) = Arc::into_inner(fragment.file) {
+                                finished_files_tx.try_send(file.into_inner()).unwrap();
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        log::error!("{:?}", error);
+                    }
                 }
-                Err(error) => {
-                    log::error!("{:?}", error);
-                }
-            }
-        })
-        .detach();
+            })
+            .detach();
     }
 
     pub fn finished_files(&self) -> channel::Receiver<FileToEmbed> {
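
The `push` method above enforces a per-batch token budget: spans that already carry a cached embedding cost zero tokens, and the pending batch is flushed whenever the next span would exceed the provider's limit. A simplified, dependency-free sketch of that rule (the types here are illustrative, not the crate's own):

    struct Span {
        token_count: usize,
        has_embedding: bool,
    }

    /// Groups span indices into batches that each stay within the token budget.
    fn batch_spans(spans: &[Span], max_tokens_per_batch: usize) -> Vec<Vec<usize>> {
        let mut batches = vec![Vec::new()];
        let mut batch_token_count = 0;
        for (ix, span) in spans.iter().enumerate() {
            // Spans with cached embeddings still travel with their file,
            // but contribute nothing to the token budget.
            let cost = if span.has_embedding { 0 } else { span.token_count };
            if batch_token_count + cost > max_tokens_per_batch
                && !batches.last().unwrap().is_empty()
            {
                batches.push(Vec::new());
                batch_token_count = 0;
            }
            batches.last_mut().unwrap().push(ix);
            batch_token_count += cost;
        }
        batches
    }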

crates/semantic_index/src/parsing.rs

@@ -16,9 +16,9 @@ use std::{
 use tree_sitter::{Parser, QueryCursor};
 
 #[derive(Debug, PartialEq, Eq, Clone, Hash)]
-pub struct DocumentDigest([u8; 20]);
+pub struct SpanDigest([u8; 20]);
 
-impl FromSql for DocumentDigest {
+impl FromSql for SpanDigest {
     fn column_result(value: ValueRef) -> FromSqlResult<Self> {
         let blob = value.as_blob()?;
         let bytes =
@@ -27,17 +27,17 @@ impl FromSql for DocumentDigest {
                     expected_size: 20,
                     blob_size: blob.len(),
                 })?;
-        return Ok(DocumentDigest(bytes));
+        return Ok(SpanDigest(bytes));
     }
 }
 
-impl ToSql for DocumentDigest {
+impl ToSql for SpanDigest {
     fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
         self.0.to_sql()
     }
 }
 
-impl From<&'_ str> for DocumentDigest {
+impl From<&'_ str> for SpanDigest {
     fn from(value: &'_ str) -> Self {
         let mut sha1 = Sha1::new();
         sha1.update(value);
@@ -46,12 +46,12 @@ impl From<&'_ str> for DocumentDigest {
 }
 
 #[derive(Debug, PartialEq, Clone)]
-pub struct Document {
+pub struct Span {
     pub name: String,
     pub range: Range<usize>,
     pub content: String,
     pub embedding: Option<Embedding>,
-    pub digest: DocumentDigest,
+    pub digest: SpanDigest,
     pub token_count: usize,
 }
 
@@ -97,14 +97,14 @@ impl CodeContextRetriever {
         relative_path: &Path,
         language_name: Arc<str>,
         content: &str,
-    ) -> Result<Vec<Document>> {
+    ) -> Result<Vec<Span>> {
         let document_span = ENTIRE_FILE_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
-        let digest = DocumentDigest::from(document_span.as_str());
+        let digest = SpanDigest::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-        Ok(vec![Document {
+        Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
             embedding: Default::default(),
@@ -114,13 +114,13 @@ impl CodeContextRetriever {
         }])
     }
 
-    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Document>> {
+    fn parse_markdown_file(&self, relative_path: &Path, content: &str) -> Result<Vec<Span>> {
         let document_span = MARKDOWN_CONTEXT_TEMPLATE
             .replace("<path>", relative_path.to_string_lossy().as_ref())
             .replace("<item>", &content);
-        let digest = DocumentDigest::from(document_span.as_str());
+        let digest = SpanDigest::from(document_span.as_str());
         let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
-        Ok(vec![Document {
+        Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
             embedding: None,
@@ -191,32 +191,32 @@ impl CodeContextRetriever {
         relative_path: &Path,
         content: &str,
         language: Arc<Language>,
-    ) -> Result<Vec<Document>> {
+    ) -> Result<Vec<Span>> {
         let language_name = language.name();
 
         if PARSEABLE_ENTIRE_FILE_TYPES.contains(&language_name.as_ref()) {
             return self.parse_entire_file(relative_path, language_name, &content);
-        } else if &language_name.to_string() == &"Markdown".to_string() {
+        } else if language_name.as_ref() == "Markdown" {
             return self.parse_markdown_file(relative_path, &content);
         }
 
-        let mut documents = self.parse_file(content, language)?;
-        for document in &mut documents {
+        let mut spans = self.parse_file(content, language)?;
+        for span in &mut spans {
             let document_content = CODE_CONTEXT_TEMPLATE
                 .replace("<path>", relative_path.to_string_lossy().as_ref())
                 .replace("<language>", language_name.as_ref())
-                .replace("item", &document.content);
+                .replace("item", &span.content);
 
             let (document_content, token_count) =
                 self.embedding_provider.truncate(&document_content);
 
-            document.content = document_content;
-            document.token_count = token_count;
+            span.content = document_content;
+            span.token_count = token_count;
         }
-        Ok(documents)
+        Ok(spans)
     }
 
-    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Document>> {
+    pub fn parse_file(&mut self, content: &str, language: Arc<Language>) -> Result<Vec<Span>> {
         let grammar = language
             .grammar()
             .ok_or_else(|| anyhow!("no grammar for language"))?;
@@ -227,7 +227,7 @@ impl CodeContextRetriever {
         let language_scope = language.default_scope();
         let placeholder = language_scope.collapsed_placeholder();
 
-        let mut documents = Vec::new();
+        let mut spans = Vec::new();
         let mut collapsed_ranges_within = Vec::new();
         let mut parsed_name_ranges = HashSet::new();
         for (i, context_match) in matches.iter().enumerate() {
@@ -267,22 +267,22 @@ impl CodeContextRetriever {
 
             collapsed_ranges_within.sort_by_key(|r| (r.start, Reverse(r.end)));
 
-            let mut document_content = String::new();
+            let mut span_content = String::new();
             for context_range in &context_match.context_ranges {
                 add_content_from_range(
-                    &mut document_content,
+                    &mut span_content,
                     content,
                     context_range.clone(),
                     context_match.start_col,
                 );
-                document_content.push_str("\n");
+                span_content.push_str("\n");
             }
 
             let mut offset = item_range.start;
             for collapsed_range in &collapsed_ranges_within {
                 if collapsed_range.start > offset {
                     add_content_from_range(
-                        &mut document_content,
+                        &mut span_content,
                         content,
                         offset..collapsed_range.start,
                         context_match.start_col,
@@ -291,24 +291,24 @@ impl CodeContextRetriever {
                 }
 
                 if collapsed_range.end > offset {
-                    document_content.push_str(placeholder);
+                    span_content.push_str(placeholder);
                     offset = collapsed_range.end;
                 }
             }
 
             if offset < item_range.end {
                 add_content_from_range(
-                    &mut document_content,
+                    &mut span_content,
                     content,
                     offset..item_range.end,
                     context_match.start_col,
                 );
             }
 
-            let sha1 = DocumentDigest::from(document_content.as_str());
-            documents.push(Document {
+            let sha1 = SpanDigest::from(span_content.as_str());
+            spans.push(Span {
                 name,
-                content: document_content,
+                content: span_content,
                 range: item_range.clone(),
                 embedding: None,
                 digest: sha1,
@@ -316,7 +316,7 @@ impl CodeContextRetriever {
             })
         }
 
-        return Ok(documents);
+        return Ok(spans);
     }
 }
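
Renaming `Document` to `Span` keeps the digest-based caching intact: a span's rendered content is hashed with SHA-1, and a previously computed embedding is reused whenever the digest matches. A sketch of that digest-and-reuse flow, assuming the `sha1` crate's 0.10-style API and an illustrative `Embedding` type:

    use sha1::{Digest, Sha1};
    use std::collections::HashMap;

    #[derive(Clone, Debug, PartialEq, Eq, Hash)]
    struct SpanDigest([u8; 20]);

    impl From<&str> for SpanDigest {
        fn from(content: &str) -> Self {
            let mut sha1 = Sha1::new();
            sha1.update(content);
            SpanDigest(sha1.finalize().into())
        }
    }

    #[derive(Clone)]
    struct Embedding(Vec<f32>);

    /// Returns the cached embedding for `content`, if its digest is unchanged.
    fn reuse_embedding(content: &str, cache: &HashMap<SpanDigest, Embedding>) -> Option<Embedding> {
        cache.get(&SpanDigest::from(content)).cloned()
    }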
 

crates/semantic_index/src/semantic_index.rs

@@ -9,22 +9,21 @@ mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
 use anyhow::{anyhow, Result};
+use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
 use embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
 use embedding_queue::{EmbeddingQueue, FileToEmbed};
-use futures::{FutureExt, StreamExt};
+use futures::{future, FutureExt, StreamExt};
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task, WeakModelHandle};
-use language::{Anchor, Buffer, Language, LanguageRegistry};
+use language::{Anchor, Bias, Buffer, Language, LanguageRegistry};
 use parking_lot::Mutex;
-use parsing::{CodeContextRetriever, DocumentDigest, PARSEABLE_ENTIRE_FILE_TYPES};
+use parsing::{CodeContextRetriever, SpanDigest, PARSEABLE_ENTIRE_FILE_TYPES};
 use postage::watch;
-use project::{
-    search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId,
-};
+use project::{search::PathMatcher, Fs, PathChange, Project, ProjectEntryId, Worktree, WorktreeId};
 use smol::channel;
 use std::{
     cmp::Ordering,
-    collections::{BTreeMap, HashMap},
+    future::Future,
     ops::Range,
     path::{Path, PathBuf},
     sync::{Arc, Weak},
@@ -36,9 +35,8 @@ use util::{
     paths::EMBEDDINGS_DIR,
     ResultExt,
 };
-use workspace::WorkspaceCreated;
 
-const SEMANTIC_INDEX_VERSION: usize = 9;
+const SEMANTIC_INDEX_VERSION: usize = 10;
 const BACKGROUND_INDEXING_DELAY: Duration = Duration::from_secs(5 * 60);
 const EMBEDDING_QUEUE_FLUSH_TIMEOUT: Duration = Duration::from_millis(250);
 
@@ -59,24 +57,6 @@ pub fn init(
         return;
     }
 
-    cx.subscribe_global::<WorkspaceCreated, _>({
-        move |event, cx| {
-            let Some(semantic_index) = SemanticIndex::global(cx) else {
-                return;
-            };
-            let workspace = &event.0;
-            if let Some(workspace) = workspace.upgrade(cx) {
-                let project = workspace.read(cx).project().clone();
-                if project.read(cx).is_local() {
-                    semantic_index.update(cx, |index, cx| {
-                        index.initialize_project(project, cx).detach_and_log_err(cx)
-                    });
-                }
-            }
-        }
-    })
-    .detach();
-
     cx.spawn(move |mut cx| async move {
         let semantic_index = SemanticIndex::new(
             fs,
@@ -104,22 +84,78 @@ pub struct SemanticIndex {
     db: VectorDatabase,
     embedding_provider: Arc<dyn EmbeddingProvider>,
     language_registry: Arc<LanguageRegistry>,
-    parsing_files_tx: channel::Sender<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>,
+    parsing_files_tx: channel::Sender<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>,
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
 }
 
 struct ProjectState {
-    worktree_db_ids: Vec<(WorktreeId, i64)>,
+    worktrees: HashMap<WorktreeId, WorktreeState>,
+    pending_file_count_rx: watch::Receiver<usize>,
+    pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
     _subscription: gpui::Subscription,
-    outstanding_job_count_rx: watch::Receiver<usize>,
-    outstanding_job_count_tx: Arc<Mutex<watch::Sender<usize>>>,
-    changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
+}
+
+enum WorktreeState {
+    Registering(RegisteringWorktreeState),
+    Registered(RegisteredWorktreeState),
+}
+
+impl WorktreeState {
+    fn paths_changed(
+        &mut self,
+        changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
+        worktree: &Worktree,
+    ) {
+        let changed_paths = match self {
+            Self::Registering(state) => &mut state.changed_paths,
+            Self::Registered(state) => &mut state.changed_paths,
+        };
+
+        for (path, entry_id, change) in changes.iter() {
+            let Some(entry) = worktree.entry_for_id(*entry_id) else {
+                continue;
+            };
+            if entry.is_ignored || entry.is_symlink || entry.is_external || entry.is_dir() {
+                continue;
+            }
+            changed_paths.insert(
+                path.clone(),
+                ChangedPathInfo {
+                    mtime: entry.mtime,
+                    is_deleted: *change == PathChange::Removed,
+                },
+            );
+        }
+    }
+}
+
+struct RegisteringWorktreeState {
+    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
+    done_rx: watch::Receiver<Option<()>>,
+    _registration: Task<()>,
+}
+
+impl RegisteringWorktreeState {
+    fn done(&self) -> impl Future<Output = ()> {
+        let mut done_rx = self.done_rx.clone();
+        async move {
+            while let Some(result) = done_rx.next().await {
+                if result.is_some() {
+                    break;
+                }
+            }
+        }
+    }
+}
+
+struct RegisteredWorktreeState {
+    db_id: i64,
+    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
 }
 
 struct ChangedPathInfo {
-    changed_at: Instant,
     mtime: SystemTime,
     is_deleted: bool,
 }
@@ -141,47 +177,23 @@ impl JobHandle {
 }
 
 impl ProjectState {
-    fn new(
-        subscription: gpui::Subscription,
-        worktree_db_ids: Vec<(WorktreeId, i64)>,
-        changed_paths: BTreeMap<ProjectPath, ChangedPathInfo>,
-    ) -> Self {
-        let (outstanding_job_count_tx, outstanding_job_count_rx) = watch::channel_with(0);
-        let outstanding_job_count_tx = Arc::new(Mutex::new(outstanding_job_count_tx));
+    fn new(subscription: gpui::Subscription) -> Self {
+        let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
+        let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
         Self {
-            worktree_db_ids,
-            outstanding_job_count_rx,
-            outstanding_job_count_tx,
-            changed_paths,
+            worktrees: Default::default(),
+            pending_file_count_rx,
+            pending_file_count_tx,
             _subscription: subscription,
         }
     }
 
-    pub fn get_outstanding_count(&self) -> usize {
-        self.outstanding_job_count_rx.borrow().clone()
-    }
-
-    fn db_id_for_worktree_id(&self, id: WorktreeId) -> Option<i64> {
-        self.worktree_db_ids
-            .iter()
-            .find_map(|(worktree_id, db_id)| {
-                if *worktree_id == id {
-                    Some(*db_id)
-                } else {
-                    None
-                }
-            })
-    }
-
     fn worktree_id_for_db_id(&self, id: i64) -> Option<WorktreeId> {
-        self.worktree_db_ids
+        self.worktrees
             .iter()
-            .find_map(|(worktree_id, db_id)| {
-                if *db_id == id {
-                    Some(*worktree_id)
-                } else {
-                    None
-                }
+            .find_map(|(worktree_id, worktree_state)| match worktree_state {
+                WorktreeState::Registered(state) if state.db_id == id => Some(*worktree_id),
+                _ => None,
             })
     }
 }
@@ -189,7 +201,7 @@ impl ProjectState {
 #[derive(Clone)]
 pub struct PendingFile {
     worktree_db_id: i64,
-    relative_path: PathBuf,
+    relative_path: Arc<Path>,
     absolute_path: PathBuf,
     language: Option<Arc<Language>>,
     modified_time: SystemTime,
@@ -240,16 +252,16 @@ impl SemanticIndex {
                 let db = db.clone();
                 async move {
                     while let Ok(file) = embedded_files.recv().await {
-                        db.insert_file(file.worktree_id, file.path, file.mtime, file.documents)
+                        db.insert_file(file.worktree_id, file.path, file.mtime, file.spans)
                             .await
                             .log_err();
                     }
                 }
             });
 
-            // Parse files into embeddable documents.
+            // Parse files into embeddable spans.
             let (parsing_files_tx, parsing_files_rx) =
-                channel::unbounded::<(Arc<HashMap<DocumentDigest, Embedding>>, PendingFile)>();
+                channel::unbounded::<(Arc<HashMap<SpanDigest, Embedding>>, PendingFile)>();
             let embedding_queue = Arc::new(Mutex::new(embedding_queue));
             let mut _parsing_files_tasks = Vec::new();
             for _ in 0..cx.background().num_cpus() {
@@ -298,7 +310,7 @@ impl SemanticIndex {
                 parsing_files_tx,
                 _embedding_task,
                 _parsing_files_tasks,
-                projects: HashMap::new(),
+                projects: Default::default(),
             }
         }))
     }
@@ -308,26 +320,26 @@ impl SemanticIndex {
         pending_file: PendingFile,
         retriever: &mut CodeContextRetriever,
         embedding_queue: &Arc<Mutex<EmbeddingQueue>>,
-        embeddings_for_digest: &HashMap<DocumentDigest, Embedding>,
+        embeddings_for_digest: &HashMap<SpanDigest, Embedding>,
     ) {
         let Some(language) = pending_file.language else {
             return;
         };
 
         if let Some(content) = fs.load(&pending_file.absolute_path).await.log_err() {
-            if let Some(mut documents) = retriever
+            if let Some(mut spans) = retriever
                 .parse_file_with_template(&pending_file.relative_path, &content, language)
                 .log_err()
             {
                 log::trace!(
-                    "parsed path {:?}: {} documents",
+                    "parsed path {:?}: {} spans",
                     pending_file.relative_path,
-                    documents.len()
+                    spans.len()
                 );
 
-                for document in documents.iter_mut() {
-                    if let Some(embedding) = embeddings_for_digest.get(&document.digest) {
-                        document.embedding = Some(embedding.to_owned());
+                for span in &mut spans {
+                    if let Some(embedding) = embeddings_for_digest.get(&span.digest) {
+                        span.embedding = Some(embedding.to_owned());
                     }
                 }
 
@@ -336,7 +348,7 @@ impl SemanticIndex {
                     path: pending_file.relative_path,
                     mtime: pending_file.modified_time,
                     job_handle: pending_file.job_handle,
-                    documents,
+                    spans: spans,
                 });
             }
         }
@@ -369,9 +381,9 @@ impl SemanticIndex {
     fn project_entries_changed(
         &mut self,
         project: ModelHandle<Project>,
+        worktree_id: WorktreeId,
         changes: Arc<[(Arc<Path>, ProjectEntryId, PathChange)]>,
-        cx: &mut ModelContext<'_, SemanticIndex>,
-        worktree_id: &WorktreeId,
+        cx: &mut ModelContext<Self>,
     ) {
         let Some(worktree) = project.read(cx).worktree_for_id(worktree_id.clone(), cx) else {
             return;
@@ -381,258 +393,219 @@ impl SemanticIndex {
             return;
         };
 
-        let embeddings_for_digest = {
-            let mut worktree_id_file_paths = HashMap::new();
-            for (path, _) in &project_state.changed_paths {
-                if let Some(worktree_db_id) = project_state.db_id_for_worktree_id(path.worktree_id)
-                {
-                    worktree_id_file_paths
-                        .entry(worktree_db_id)
-                        .or_insert(Vec::new())
-                        .push(path.path.clone());
-                }
-            }
-            self.db.embeddings_for_files(worktree_id_file_paths)
-        };
-
         let worktree = worktree.read(cx);
-        let change_time = Instant::now();
-        for (path, entry_id, change) in changes.iter() {
-            let Some(entry) = worktree.entry_for_id(*entry_id) else {
-                continue;
-            };
-            if entry.is_ignored || entry.is_symlink || entry.is_external {
-                continue;
-            }
-            let project_path = ProjectPath {
-                worktree_id: *worktree_id,
-                path: path.clone(),
+        let worktree_state =
+            if let Some(worktree_state) = project_state.worktrees.get_mut(&worktree_id) {
+                worktree_state
+            } else {
+                return;
             };
-            project_state.changed_paths.insert(
-                project_path,
-                ChangedPathInfo {
-                    changed_at: change_time,
-                    mtime: entry.mtime,
-                    is_deleted: *change == PathChange::Removed,
-                },
-            );
+        worktree_state.paths_changed(changes, worktree);
+        if let WorktreeState::Registered(_) = worktree_state {
+            cx.spawn_weak(|this, mut cx| async move {
+                cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
+                if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
+                    this.update(&mut cx, |this, cx| {
+                        this.index_project(project, cx).detach_and_log_err(cx)
+                    });
+                }
+            })
+            .detach();
         }
-
-        cx.spawn_weak(|this, mut cx| async move {
-            let embeddings_for_digest = embeddings_for_digest.await.log_err().unwrap_or_default();
-
-            cx.background().timer(BACKGROUND_INDEXING_DELAY).await;
-            if let Some((this, project)) = this.upgrade(&cx).zip(project.upgrade(&cx)) {
-                Self::reindex_changed_paths(
-                    this,
-                    project,
-                    Some(change_time),
-                    &mut cx,
-                    Arc::new(embeddings_for_digest),
-                )
-                .await;
-            }
-        })
-        .detach();
     }
 
-    pub fn initialize_project(
+    fn register_worktree(
         &mut self,
         project: ModelHandle<Project>,
+        worktree: ModelHandle<Worktree>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
-        log::trace!("Initializing Project for Semantic Index");
-        let worktree_scans_complete = project
-            .read(cx)
-            .worktrees(cx)
-            .map(|worktree| {
-                let scan_complete = worktree.read(cx).as_local().unwrap().scan_complete();
-                async move {
-                    scan_complete.await;
-                }
-            })
-            .collect::<Vec<_>>();
-
-        let worktree_db_ids = project
-            .read(cx)
-            .worktrees(cx)
-            .map(|worktree| {
-                self.db
-                    .find_or_create_worktree(worktree.read(cx).abs_path().to_path_buf())
-            })
-            .collect::<Vec<_>>();
-
-        let _subscription = cx.subscribe(&project, |this, project, event, cx| {
-            if let project::Event::WorktreeUpdatedEntries(worktree_id, changes) = event {
-                this.project_entries_changed(project.clone(), changes.clone(), cx, worktree_id);
-            };
-        });
-
+    ) {
+        let project = project.downgrade();
+        let project_state = if let Some(project_state) = self.projects.get_mut(&project) {
+            project_state
+        } else {
+            return;
+        };
+        let worktree = if let Some(worktree) = worktree.read(cx).as_local() {
+            worktree
+        } else {
+            return;
+        };
+        let worktree_abs_path = worktree.abs_path().clone();
+        let scan_complete = worktree.scan_complete();
+        let worktree_id = worktree.id();
+        let db = self.db.clone();
         let language_registry = self.language_registry.clone();
-
-        cx.spawn(|this, mut cx| async move {
-            futures::future::join_all(worktree_scans_complete).await;
-
-            let worktree_db_ids = futures::future::join_all(worktree_db_ids).await;
-            let worktrees = project.read_with(&cx, |project, cx| {
-                project
-                    .worktrees(cx)
-                    .map(|worktree| worktree.read(cx).snapshot())
-                    .collect::<Vec<_>>()
-            });
-
-            let mut worktree_file_mtimes = HashMap::new();
-            let mut db_ids_by_worktree_id = HashMap::new();
-
-            for (worktree, db_id) in worktrees.iter().zip(worktree_db_ids) {
-                let db_id = db_id?;
-                db_ids_by_worktree_id.insert(worktree.id(), db_id);
-                worktree_file_mtimes.insert(
-                    worktree.id(),
-                    this.read_with(&cx, |this, _| this.db.get_file_mtimes(db_id))
-                        .await?,
-                );
-            }
-
-            let worktree_db_ids = db_ids_by_worktree_id
-                .iter()
-                .map(|(a, b)| (*a, *b))
-                .collect();
-
-            let changed_paths = cx
-                .background()
-                .spawn(async move {
-                    let mut changed_paths = BTreeMap::new();
-                    let now = Instant::now();
-                    for worktree in worktrees.into_iter() {
-                        let mut file_mtimes = worktree_file_mtimes.remove(&worktree.id()).unwrap();
-                        for file in worktree.files(false, 0) {
-                            let absolute_path = worktree.absolutize(&file.path);
-
-                            if file.is_external || file.is_ignored || file.is_symlink {
-                                continue;
-                            }
-
-                            if let Ok(language) = language_registry
-                                .language_for_file(&absolute_path, None)
-                                .await
-                            {
-                                // Test if file is valid parseable file
-                                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                                    && &language.name().as_ref() != &"Markdown"
-                                    && language
-                                        .grammar()
-                                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                                        .is_none()
-                                {
+        let (mut done_tx, done_rx) = watch::channel();
+        let registration = cx.spawn(|this, mut cx| {
+            async move {
+                let register = async {
+                    scan_complete.await;
+                    let db_id = db.find_or_create_worktree(worktree_abs_path).await?;
+                    let mut file_mtimes = db.get_file_mtimes(db_id).await?;
+                    let worktree = if let Some(project) = project.upgrade(&cx) {
+                        project
+                            .read_with(&cx, |project, cx| project.worktree_for_id(worktree_id, cx))
+                            .ok_or_else(|| anyhow!("worktree not found"))?
+                    } else {
+                        return anyhow::Ok(());
+                    };
+                    let worktree = worktree.read_with(&cx, |worktree, _| worktree.snapshot());
+                    let mut changed_paths = cx
+                        .background()
+                        .spawn(async move {
+                            let mut changed_paths = BTreeMap::new();
+                            for file in worktree.files(false, 0) {
+                                let absolute_path = worktree.absolutize(&file.path);
+
+                                if file.is_external || file.is_ignored || file.is_symlink {
                                     continue;
                                 }
 
-                                let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
-                                let already_stored = stored_mtime
-                                    .map_or(false, |existing_mtime| existing_mtime == file.mtime);
-
-                                if !already_stored {
-                                    changed_paths.insert(
-                                        ProjectPath {
-                                            worktree_id: worktree.id(),
-                                            path: file.path.clone(),
-                                        },
-                                        ChangedPathInfo {
-                                            changed_at: now,
-                                            mtime: file.mtime,
-                                            is_deleted: false,
-                                        },
-                                    );
+                                if let Ok(language) = language_registry
+                                    .language_for_file(&absolute_path, None)
+                                    .await
+                                {
+                                    // Test if file is valid parseable file
+                                    if !PARSEABLE_ENTIRE_FILE_TYPES
+                                        .contains(&language.name().as_ref())
+                                        && &language.name().as_ref() != &"Markdown"
+                                        && language
+                                            .grammar()
+                                            .and_then(|grammar| grammar.embedding_config.as_ref())
+                                            .is_none()
+                                    {
+                                        continue;
+                                    }
+
+                                    let stored_mtime = file_mtimes.remove(&file.path.to_path_buf());
+                                    let already_stored = stored_mtime
+                                        .map_or(false, |existing_mtime| {
+                                            existing_mtime == file.mtime
+                                        });
+
+                                    if !already_stored {
+                                        changed_paths.insert(
+                                            file.path.clone(),
+                                            ChangedPathInfo {
+                                                mtime: file.mtime,
+                                                is_deleted: false,
+                                            },
+                                        );
+                                    }
                                 }
                             }
-                        }
 
-                        // Clean up entries from database that are no longer in the worktree.
-                        for (path, mtime) in file_mtimes {
-                            changed_paths.insert(
-                                ProjectPath {
-                                    worktree_id: worktree.id(),
-                                    path: path.into(),
-                                },
-                                ChangedPathInfo {
-                                    changed_at: now,
-                                    mtime,
-                                    is_deleted: true,
-                                },
-                            );
+                            // Clean up entries from database that are no longer in the worktree.
+                            for (path, mtime) in file_mtimes {
+                                changed_paths.insert(
+                                    path.into(),
+                                    ChangedPathInfo {
+                                        mtime,
+                                        is_deleted: true,
+                                    },
+                                );
+                            }
+
+                            anyhow::Ok(changed_paths)
+                        })
+                        .await?;
+                    this.update(&mut cx, |this, cx| {
+                        let project_state = this
+                            .projects
+                            .get_mut(&project)
+                            .ok_or_else(|| anyhow!("project not registered"))?;
+                        let project = project
+                            .upgrade(cx)
+                            .ok_or_else(|| anyhow!("project was dropped"))?;
+
+                        if let Some(WorktreeState::Registering(state)) =
+                            project_state.worktrees.remove(&worktree_id)
+                        {
+                            changed_paths.extend(state.changed_paths);
                         }
-                    }
+                        project_state.worktrees.insert(
+                            worktree_id,
+                            WorktreeState::Registered(RegisteredWorktreeState {
+                                db_id,
+                                changed_paths,
+                            }),
+                        );
+                        this.index_project(project, cx).detach_and_log_err(cx);
+
+                        anyhow::Ok(())
+                    })?;
+
+                    anyhow::Ok(())
+                };
 
-                    anyhow::Ok(changed_paths)
-                })
-                .await?;
+                if register.await.log_err().is_none() {
+                    // Stop tracking this worktree if the registration failed.
+                    this.update(&mut cx, |this, _| {
+                        this.projects.get_mut(&project).map(|project_state| {
+                            project_state.worktrees.remove(&worktree_id);
+                        });
+                    })
+                }
 
-            this.update(&mut cx, |this, _| {
-                this.projects.insert(
-                    project.downgrade(),
-                    ProjectState::new(_subscription, worktree_db_ids, changed_paths),
-                );
-            });
-            Result::<(), _>::Ok(())
-        })
+                *done_tx.borrow_mut() = Some(());
+            }
+        });
+        project_state.worktrees.insert(
+            worktree_id,
+            WorktreeState::Registering(RegisteringWorktreeState {
+                changed_paths: Default::default(),
+                done_rx,
+                _registration: registration,
+            }),
+        );
     }
 
-    pub fn index_project(
+    fn project_worktrees_changed(
         &mut self,
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<(usize, watch::Receiver<usize>)>> {
-        cx.spawn(|this, mut cx| async move {
-            let embeddings_for_digest = this.read_with(&cx, |this, _| {
-                if let Some(state) = this.projects.get(&project.downgrade()) {
-                    let mut worktree_id_file_paths = HashMap::default();
-                    for (path, _) in &state.changed_paths {
-                        if let Some(worktree_db_id) = state.db_id_for_worktree_id(path.worktree_id)
-                        {
-                            worktree_id_file_paths
-                                .entry(worktree_db_id)
-                                .or_insert(Vec::new())
-                                .push(path.path.clone());
-                        }
-                    }
-
-                    Ok(this.db.embeddings_for_files(worktree_id_file_paths))
-                } else {
-                    Err(anyhow!("Project not yet initialized"))
-                }
-            })?;
-
-            let embeddings_for_digest = Arc::new(embeddings_for_digest.await?);
-
-            Self::reindex_changed_paths(
-                this.clone(),
-                project.clone(),
-                None,
-                &mut cx,
-                embeddings_for_digest,
-            )
-            .await;
+    ) {
+        let Some(project_state) = self.projects.get_mut(&project.downgrade()) else {
+            return;
+        };
 
-            this.update(&mut cx, |this, _cx| {
-                let Some(state) = this.projects.get(&project.downgrade()) else {
-                    return Err(anyhow!("Project not yet initialized"));
-                };
-                let job_count_rx = state.outstanding_job_count_rx.clone();
-                let count = state.get_outstanding_count();
-                Ok((count, job_count_rx))
-            })
-        })
+        let mut worktrees = project
+            .read(cx)
+            .worktrees(cx)
+            .filter(|worktree| worktree.read(cx).is_local())
+            .collect::<Vec<_>>();
+        let worktree_ids = worktrees
+            .iter()
+            .map(|worktree| worktree.read(cx).id())
+            .collect::<HashSet<_>>();
+
+        // Remove worktrees that are no longer present
+        project_state
+            .worktrees
+            .retain(|worktree_id, _| worktree_ids.contains(worktree_id));
+
+        // Register new worktrees
+        worktrees.retain(|worktree| {
+            let worktree_id = worktree.read(cx).id();
+            !project_state.worktrees.contains_key(&worktree_id)
+        });
+        for worktree in worktrees {
+            self.register_worktree(project.clone(), worktree, cx);
+        }
     }
 
-    pub fn outstanding_job_count_rx(
+    pub fn pending_file_count(
         &self,
         project: &ModelHandle<Project>,
     ) -> Option<watch::Receiver<usize>> {
         Some(
             self.projects
                 .get(&project.downgrade())?
-                .outstanding_job_count_rx
+                .pending_file_count_rx
                 .clone(),
         )
     }
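
Registration drives a small per-worktree state machine: while the database row is being created, the worktree sits in a `Registering` state that buffers incoming file events, and a successful registration merges those buffered events into the `Registered` state. A minimal sketch of the shapes involved, inferred from the hunk above (the map type and the placeholder types are assumptions; the crate itself uses `postage::watch` and a gpui `Task`):

```rust
use std::{collections::BTreeMap, path::Path, sync::Arc, time::SystemTime};

// Placeholders so this sketch stands alone; the crate uses a
// postage::watch receiver for `done_rx` and a gpui Task for the
// registration handle.
type DoneReceiver = std::sync::mpsc::Receiver<()>;
type RegistrationTask = std::thread::JoinHandle<()>;

struct ChangedPathInfo {
    mtime: SystemTime,
    is_deleted: bool,
}

enum WorktreeState {
    Registering(RegisteringWorktreeState),
    Registered(RegisteredWorktreeState),
}

struct RegisteringWorktreeState {
    // File events that arrive before the worktree has a database id;
    // merged into the registered state once registration succeeds.
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
    // Fires when the registration task completes, successfully or not.
    done_rx: DoneReceiver,
    // Dropping this cancels an in-flight registration.
    _registration: RegistrationTask,
}

struct RegisteredWorktreeState {
    db_id: i64,
    changed_paths: BTreeMap<Arc<Path>, ChangedPathInfo>,
}
```

If registration fails, the `Registering` entry is simply removed again (the `log_err` branch above), so a later `index_project` call can retry from scratch.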
@@ -646,25 +619,13 @@ impl SemanticIndex {
         excludes: Vec<PathMatcher>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<SearchResult>>> {
-        let project_state = if let Some(state) = self.projects.get(&project.downgrade()) {
-            state
-        } else {
-            return Task::ready(Err(anyhow!("project not added")));
-        };
-
-        let worktree_db_ids = project
-            .read(cx)
-            .worktrees(cx)
-            .filter_map(|worktree| {
-                let worktree_id = worktree.read(cx).id();
-                project_state.db_id_for_worktree_id(worktree_id)
-            })
-            .collect::<Vec<_>>();
-
+        let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
         let db_path = self.db.path().clone();
         let fs = self.fs.clone();
         cx.spawn(|this, mut cx| async move {
+            index.await?;
+
             let t0 = Instant::now();
             let database =
                 VectorDatabase::new(fs.clone(), db_path.clone(), cx.background()).await?;
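
Awaiting `index.await?` up front removes the window where a search could run against a project whose worktrees had not finished registering: `search_project` now implicitly indexes first, then queries. The resulting call shape, mirroring the updated test further down (query and limit are illustrative values):

```rust
// No explicit indexing call is needed before searching; the returned
// task awaits index_project internally before querying the database.
let search = semantic_index.update(cx, |index, cx| {
    index.search_project(project.clone(), "query".to_string(), 10, vec![], vec![], cx)
});
let results = search.await?;
```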
@@ -681,6 +642,24 @@ impl SemanticIndex {
                 t0.elapsed().as_millis()
             );
 
+            let worktree_db_ids = this.read_with(&cx, |this, _| {
+                let project_state = this
+                    .projects
+                    .get(&project.downgrade())
+                    .ok_or_else(|| anyhow!("project was not indexed"))?;
+                let worktree_db_ids = project_state
+                    .worktrees
+                    .values()
+                    .filter_map(|worktree| {
+                        if let WorktreeState::Registered(worktree) = worktree {
+                            Some(worktree.db_id)
+                        } else {
+                            None
+                        }
+                    })
+                    .collect::<Vec<i64>>();
+                anyhow::Ok(worktree_db_ids)
+            })?;
             let file_ids = database
                 .retrieve_included_file_ids(&worktree_db_ids, &includes, &excludes)
                 .await?;
@@ -729,13 +708,13 @@ impl SemanticIndex {
             }
 
             let ids = results.into_iter().map(|(id, _)| id).collect::<Vec<i64>>();
-            let documents = database.get_documents_by_ids(ids.as_slice()).await?;
+            let spans = database.spans_for_ids(ids.as_slice()).await?;
 
             let mut tasks = Vec::new();
             let mut ranges = Vec::new();
             let weak_project = project.downgrade();
             project.update(&mut cx, |project, cx| {
-                for (worktree_db_id, file_path, byte_range) in documents {
+                for (worktree_db_id, file_path, byte_range) in spans {
                     let project_state =
                         if let Some(state) = this.read(cx).projects.get(&weak_project) {
                             state
@@ -764,7 +743,9 @@ impl SemanticIndex {
                 .filter_map(|(buffer, range)| {
                     let buffer = buffer.log_err()?;
                     let range = buffer.read_with(&cx, |buffer, _| {
-                        buffer.anchor_before(range.start)..buffer.anchor_after(range.end)
+                        let start = buffer.clip_offset(range.start, Bias::Left);
+                        let end = buffer.clip_offset(range.end, Bias::Right);
+                        buffer.anchor_before(start)..buffer.anchor_after(end)
                     });
                     Some(SearchResult { buffer, range })
                 })
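
This clipping step matters because a span's byte range is computed at indexing time: if the buffer shrank between indexing and opening the results, anchoring the raw offsets could land past the end of the buffer. A standalone approximation of the idea (the real code uses `Buffer::clip_offset` with a `Bias`, which additionally respects character boundaries):

```rust
use std::ops::Range;

// Clamp a possibly stale span to the buffer's current length, keeping
// the range well-formed even if the file shrank since it was indexed.
fn clip_span(range: Range<usize>, buffer_len: usize) -> Range<usize> {
    let start = range.start.min(buffer_len);
    let end = range.end.min(buffer_len).max(start);
    start..end
}

fn main() {
    // Indexed as bytes 120..180, but the buffer is now only 150 bytes.
    assert_eq!(clip_span(120..180, 150), 120..150);
    // A span entirely past the end collapses to an empty range.
    assert_eq!(clip_span(200..240, 150), 150..150);
}
```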
@@ -772,95 +753,173 @@ impl SemanticIndex {
         })
     }
 
-    async fn reindex_changed_paths(
-        this: ModelHandle<SemanticIndex>,
+    pub fn index_project(
+        &mut self,
         project: ModelHandle<Project>,
-        last_changed_before: Option<Instant>,
-        cx: &mut AsyncAppContext,
-        embeddings_for_digest: Arc<HashMap<DocumentDigest, Embedding>>,
-    ) {
-        let mut pending_files = Vec::new();
-        let mut files_to_delete = Vec::new();
-        let (db, language_registry, parsing_files_tx) = this.update(cx, |this, cx| {
-            if let Some(project_state) = this.projects.get_mut(&project.downgrade()) {
-                let outstanding_job_count_tx = &project_state.outstanding_job_count_tx;
-                let db_ids = &project_state.worktree_db_ids;
-                let mut worktree: Option<ModelHandle<Worktree>> = None;
-
-                project_state.changed_paths.retain(|path, info| {
-                    if let Some(last_changed_before) = last_changed_before {
-                        if info.changed_at > last_changed_before {
-                            return true;
-                        }
-                    }
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        if !self.projects.contains_key(&project.downgrade()) {
+            log::trace!("Registering Project for Semantic Index");
+
+            let subscription = cx.subscribe(&project, |this, project, event, cx| match event {
+                project::Event::WorktreeAdded | project::Event::WorktreeRemoved(_) => {
+                    this.project_worktrees_changed(project.clone(), cx);
+                }
+                project::Event::WorktreeUpdatedEntries(worktree_id, changes) => {
+                    this.project_entries_changed(project, *worktree_id, changes.clone(), cx);
+                }
+                _ => {}
+            });
+            self.projects
+                .insert(project.downgrade(), ProjectState::new(subscription));
+            self.project_worktrees_changed(project.clone(), cx);
+        }
+        let project_state = &self.projects[&project.downgrade()];
+        let mut pending_file_count_rx = project_state.pending_file_count_rx.clone();
+
+        let db = self.db.clone();
+        let language_registry = self.language_registry.clone();
+        let parsing_files_tx = self.parsing_files_tx.clone();
+        let worktree_registration = self.wait_for_worktree_registration(&project, cx);
 
-                    if worktree
-                        .as_ref()
-                        .map_or(true, |tree| tree.read(cx).id() != path.worktree_id)
-                    {
-                        worktree = project.read(cx).worktree_for_id(path.worktree_id, cx);
+        cx.spawn(|this, mut cx| async move {
+            worktree_registration.await?;
+
+            let mut pending_files = Vec::new();
+            let mut files_to_delete = Vec::new();
+            this.update(&mut cx, |this, cx| {
+                let project_state = this
+                    .projects
+                    .get_mut(&project.downgrade())
+                    .ok_or_else(|| anyhow!("project was dropped"))?;
+                let pending_file_count_tx = &project_state.pending_file_count_tx;
+
+                project_state
+                    .worktrees
+                    .retain(|worktree_id, worktree_state| {
+                        let Some(worktree) = project.read(cx).worktree_for_id(*worktree_id, cx)
+                        else {
+                            return false;
+                        };
+                        let WorktreeState::Registered(worktree_state) = worktree_state else {
+                            return true;
+                        };
+
+                        worktree_state.changed_paths.retain(|path, info| {
+                            if info.is_deleted {
+                                files_to_delete.push((worktree_state.db_id, path.clone()));
+                            } else {
+                                let absolute_path = worktree.read(cx).absolutize(path);
+                                let job_handle = JobHandle::new(pending_file_count_tx);
+                                pending_files.push(PendingFile {
+                                    absolute_path,
+                                    relative_path: path.clone(),
+                                    language: None,
+                                    job_handle,
+                                    modified_time: info.mtime,
+                                    worktree_db_id: worktree_state.db_id,
+                                });
+                            }
+
+                            false
+                        });
+                        true
+                    });
+
+                anyhow::Ok(())
+            })?;
+
+            cx.background()
+                .spawn(async move {
+                    for (worktree_db_id, path) in files_to_delete {
+                        db.delete_file(worktree_db_id, path).await.log_err();
                     }
-                    let Some(worktree) = &worktree else {
-                        return false;
-                    };
 
-                    let Some(worktree_db_id) = db_ids
-                        .iter()
-                        .find_map(|entry| (entry.0 == path.worktree_id).then_some(entry.1))
-                    else {
-                        return false;
+                    let embeddings_for_digest = {
+                        let mut files = HashMap::default();
+                        for pending_file in &pending_files {
+                            files
+                                .entry(pending_file.worktree_db_id)
+                                .or_insert(Vec::new())
+                                .push(pending_file.relative_path.clone());
+                        }
+                        Arc::new(
+                            db.embeddings_for_files(files)
+                                .await
+                                .log_err()
+                                .unwrap_or_default(),
+                        )
                     };
 
-                    if info.is_deleted {
-                        files_to_delete.push((worktree_db_id, path.path.to_path_buf()));
-                    } else {
-                        let absolute_path = worktree.read(cx).absolutize(&path.path);
-                        let job_handle = JobHandle::new(&outstanding_job_count_tx);
-                        pending_files.push(PendingFile {
-                            absolute_path,
-                            relative_path: path.path.to_path_buf(),
-                            language: None,
-                            job_handle,
-                            modified_time: info.mtime,
-                            worktree_db_id,
-                        });
+                    for mut pending_file in pending_files {
+                        if let Ok(language) = language_registry
+                            .language_for_file(&pending_file.relative_path, None)
+                            .await
+                        {
+                            if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
+                                && language.name().as_ref() != "Markdown"
+                                && language
+                                    .grammar()
+                                    .and_then(|grammar| grammar.embedding_config.as_ref())
+                                    .is_none()
+                            {
+                                continue;
+                            }
+                            pending_file.language = Some(language);
+                        }
+                        parsing_files_tx
+                            .try_send((embeddings_for_digest.clone(), pending_file))
+                            .ok();
                     }
 
-                    false
-                });
-            }
+                    // Wait until we're done indexing.
+                    while let Some(count) = pending_file_count_rx.next().await {
+                        if count == 0 {
+                            break;
+                        }
+                    }
+                })
+                .await;
 
-            (
-                this.db.clone(),
-                this.language_registry.clone(),
-                this.parsing_files_tx.clone(),
-            )
-        });
+            Ok(())
+        })
+    }
 
-        for (worktree_db_id, path) in files_to_delete {
-            db.delete_file(worktree_db_id, path).await.log_err();
-        }
+    fn wait_for_worktree_registration(
+        &self,
+        project: &ModelHandle<Project>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let project = project.downgrade();
+        cx.spawn_weak(|this, cx| async move {
+            loop {
+                let mut pending_worktrees = Vec::new();
+                this.upgrade(&cx)
+                    .ok_or_else(|| anyhow!("semantic index dropped"))?
+                    .read_with(&cx, |this, _| {
+                        if let Some(project) = this.projects.get(&project) {
+                            for worktree in project.worktrees.values() {
+                                if let WorktreeState::Registering(worktree) = worktree {
+                                    pending_worktrees.push(worktree.done());
+                                }
+                            }
+                        }
+                    });
 
-        for mut pending_file in pending_files {
-            if let Ok(language) = language_registry
-                .language_for_file(&pending_file.relative_path, None)
-                .await
-            {
-                if !PARSEABLE_ENTIRE_FILE_TYPES.contains(&language.name().as_ref())
-                    && &language.name().as_ref() != &"Markdown"
-                    && language
-                        .grammar()
-                        .and_then(|grammar| grammar.embedding_config.as_ref())
-                        .is_none()
-                {
-                    continue;
+                if pending_worktrees.is_empty() {
+                    break;
+                } else {
+                    future::join_all(pending_worktrees).await;
                 }
-                pending_file.language = Some(language);
             }
-            parsing_files_tx
-                .try_send((embeddings_for_digest.clone(), pending_file))
-                .ok();
-        }
+            Ok(())
+        })
     }
 }
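
The `pending_file_count` channel that `index_project` drains at the end is maintained by RAII: each `PendingFile` carries a `JobHandle` that increments the shared counter on creation and decrements it on drop, so the count returns to zero even when a job is cancelled or fails. A simplified, self-contained stand-in (the crate's handle publishes through a `postage::watch` channel so the search UI can observe progress):

```rust
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

// Simplified stand-in for the crate's JobHandle: creating one counts a
// file as pending, and Drop un-counts it on every exit path.
struct JobHandle {
    count: Arc<AtomicUsize>,
}

impl JobHandle {
    fn new(count: &Arc<AtomicUsize>) -> Self {
        count.fetch_add(1, Ordering::SeqCst);
        Self {
            count: count.clone(),
        }
    }
}

impl Drop for JobHandle {
    fn drop(&mut self) {
        self.count.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    let pending = Arc::new(AtomicUsize::new(0));
    let job = JobHandle::new(&pending);
    assert_eq!(pending.load(Ordering::SeqCst), 1);
    drop(job); // the file finished embedding (or was cancelled)
    assert_eq!(pending.load(Ordering::SeqCst), 0);
}
```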
 

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     embedding::{DummyEmbeddings, Embedding, EmbeddingProvider},
     embedding_queue::EmbeddingQueue,
-    parsing::{subtract_ranges, CodeContextRetriever, Document, DocumentDigest},
+    parsing::{subtract_ranges, CodeContextRetriever, Span, SpanDigest},
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
@@ -87,34 +87,24 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
 
     let project = Project::test(fs.clone(), ["/the-root".as_ref()], cx).await;
 
-    let _ = semantic_index
-        .update(cx, |store, cx| {
-            store.initialize_project(project.clone(), cx)
-        })
-        .await;
-
-    let (file_count, outstanding_file_count) = semantic_index
-        .update(cx, |store, cx| store.index_project(project.clone(), cx))
-        .await
-        .unwrap();
-    assert_eq!(file_count, 3);
+    let search_results = semantic_index.update(cx, |store, cx| {
+        store.search_project(
+            project.clone(),
+            "aaaaaabbbbzz".to_string(),
+            5,
+            vec![],
+            vec![],
+            cx,
+        )
+    });
+    let pending_file_count =
+        semantic_index.read_with(cx, |index, _| index.pending_file_count(&project).unwrap());
+    deterministic.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 3);
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-    assert_eq!(*outstanding_file_count.borrow(), 0);
-
-    let search_results = semantic_index
-        .update(cx, |store, cx| {
-            store.search_project(
-                project.clone(),
-                "aaaaaabbbbzz".to_string(),
-                5,
-                vec![],
-                vec![],
-                cx,
-            )
-        })
-        .await
-        .unwrap();
+    assert_eq!(*pending_file_count.borrow(), 0);
 
+    let search_results = search_results.await.unwrap();
     assert_search_results(
         &search_results,
         &[
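
The test now follows the new control flow: issue the search first, watch the pending count settle, and only then await the results. The observation pattern, with the counts this test asserts:

```rust
// Drive the deterministic executor, then read the latest value from
// the watch receiver (the fixture project contains three files).
deterministic.run_until_parked(); // registration and scans settle
assert_eq!(*pending_file_count.borrow(), 3); // all files queued
deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT); // flush queue
assert_eq!(*pending_file_count.borrow(), 0); // indexing finished
```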
@@ -191,14 +181,12 @@ async fn test_semantic_index(deterministic: Arc<Deterministic>, cx: &mut TestApp
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
 
     let prev_embedding_count = embedding_provider.embedding_count();
-    let (file_count, outstanding_file_count) = semantic_index
-        .update(cx, |store, cx| store.index_project(project.clone(), cx))
-        .await
-        .unwrap();
-    assert_eq!(file_count, 1);
-
+    let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
+    deterministic.run_until_parked();
+    assert_eq!(*pending_file_count.borrow(), 1);
     deterministic.advance_clock(EMBEDDING_QUEUE_FLUSH_TIMEOUT);
-    assert_eq!(*outstanding_file_count.borrow(), 0);
+    assert_eq!(*pending_file_count.borrow(), 0);
+    index.await.unwrap();
 
     assert_eq!(
         embedding_provider.embedding_count() - prev_embedding_count,
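
Since `index_project` now returns a `Task<Result<()>>`, there are two call styles, both visible in this diff: tests await the task for a hard completion signal, while the search view detaches it and watches the counter instead. A sketch of the two usages:

```rust
// Test style: await the task; it resolves once the pending count hits 0.
let index = semantic_index.update(cx, |store, cx| store.index_project(project.clone(), cx));
index.await.unwrap();

// UI style: fire and forget, observing progress via pending_file_count.
semantic_index.update(cx, |store, cx| {
    store.index_project(project.clone(), cx).detach_and_log_err(cx);
});
```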
@@ -214,17 +202,17 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
     let files = (1..=3)
         .map(|file_ix| FileToEmbed {
             worktree_id: 5,
-            path: format!("path-{file_ix}").into(),
+            path: Path::new(&format!("path-{file_ix}")).into(),
             mtime: SystemTime::now(),
-            documents: (0..rng.gen_range(4..22))
+            spans: (0..rng.gen_range(4..22))
                 .map(|document_ix| {
                     let content_len = rng.gen_range(10..100);
                     let content = RandomCharIter::new(&mut rng)
                         .with_simple_text()
                         .take(content_len)
                         .collect::<String>();
-                    let digest = DocumentDigest::from(content.as_str());
-                    Document {
+                    let digest = SpanDigest::from(content.as_str());
+                    Span {
                         range: 0..10,
                         embedding: None,
                         name: format!("document {document_ix}"),
@@ -257,7 +245,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
         .iter()
         .map(|file| {
             let mut file = file.clone();
-            for doc in &mut file.documents {
+            for doc in &mut file.spans {
                 doc.embedding = Some(embedding_provider.embed_sync(doc.content.as_ref()));
             }
             file
@@ -449,7 +437,7 @@ async fn test_code_context_retrieval_json() {
 }
 
 fn assert_documents_eq(
-    documents: &[Document],
+    documents: &[Span],
     expected_contents_and_start_offsets: &[(String, usize)],
 ) {
     assert_eq!(