Fix some semantic index issues (#11216)

Created by Max Brunsfeld, Marshall, Nathan, Kyle, and Kyle Kelley

* [x] Fixed an issue where embeddings would be assigned incorrectly to
files if a subset of embedding batches failed
* [x] Added a command to debug which paths are present in the semantic
index
* [x] Determine why so many paths are often missing from the semantic
index
* We were erroring out if an embedding batch contained multiple identical
  texts, which can happen when a worktree contains multiple copies of the
  same text (e.g. a license file).

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Nathan <nathan@zed.dev>
Co-authored-by: Kyle <kylek@zed.dev>
Co-authored-by: Kyle Kelley <rgbkrk@gmail.com>

Change summary

crates/assistant2/examples/chat_with_functions.rs |   2 
crates/assistant2/src/assistant2.rs               |  26 +
crates/semantic_index/src/embedding/cloud.rs      |   9 
crates/semantic_index/src/semantic_index.rs       | 261 +++++++++++++---
4 files changed, 237 insertions(+), 61 deletions(-)

Detailed changes

crates/assistant2/examples/chat_with_functions.rs 🔗

@@ -365,7 +365,7 @@ impl Example {
     ) -> Self {
         Self {
             assistant_panel: cx.new_view(|cx| {
-                AssistantPanel::new(language_registry, tool_registry, user_store, cx)
+                AssistantPanel::new(language_registry, tool_registry, user_store, None, cx)
             }),
         }
     }

crates/assistant2/src/assistant2.rs 🔗

@@ -19,7 +19,7 @@ use gpui::{
 use language::{language_settings::SoftWrap, LanguageRegistry};
 use open_ai::{FunctionContent, ToolCall, ToolCallContent};
 use rich_text::RichText;
-use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
+use semantic_index::{CloudEmbeddingProvider, ProjectIndex, SemanticIndex};
 use serde::Deserialize;
 use settings::Settings;
 use std::sync::Arc;
@@ -51,7 +51,7 @@ pub enum SubmitMode {
     Codebase,
 }
 
-gpui::actions!(assistant2, [Cancel, ToggleFocus]);
+gpui::actions!(assistant2, [Cancel, ToggleFocus, DebugProjectIndex]);
 gpui::impl_actions!(assistant2, [Submit]);
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@@ -131,7 +131,13 @@ impl AssistantPanel {
 
                 let tool_registry = Arc::new(tool_registry);
 
-                Self::new(app_state.languages.clone(), tool_registry, user_store, cx)
+                Self::new(
+                    app_state.languages.clone(),
+                    tool_registry,
+                    user_store,
+                    Some(project_index),
+                    cx,
+                )
             })
         })
     }
@@ -140,6 +146,7 @@ impl AssistantPanel {
         language_registry: Arc<LanguageRegistry>,
         tool_registry: Arc<ToolRegistry>,
         user_store: Model<UserStore>,
+        project_index: Option<Model<ProjectIndex>>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let chat = cx.new_view(|cx| {
@@ -147,6 +154,7 @@ impl AssistantPanel {
                 language_registry.clone(),
                 tool_registry.clone(),
                 user_store,
+                project_index,
                 cx,
             )
         });
@@ -225,6 +233,7 @@ struct AssistantChat {
     collapsed_messages: HashMap<MessageId, bool>,
     pending_completion: Option<Task<()>>,
     tool_registry: Arc<ToolRegistry>,
+    project_index: Option<Model<ProjectIndex>>,
 }
 
 impl AssistantChat {
@@ -232,6 +241,7 @@ impl AssistantChat {
         language_registry: Arc<LanguageRegistry>,
         tool_registry: Arc<ToolRegistry>,
         user_store: Model<UserStore>,
+        project_index: Option<Model<ProjectIndex>>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let model = CompletionProvider::get(cx).default_model();
@@ -258,6 +268,7 @@ impl AssistantChat {
             list_state,
             user_store,
             language_registry,
+            project_index,
             next_message_id: MessageId(0),
             collapsed_messages: HashMap::default(),
             pending_completion: None,
@@ -342,6 +353,14 @@ impl AssistantChat {
         self.pending_completion.is_none()
     }
 
+    fn debug_project_index(&mut self, _: &DebugProjectIndex, cx: &mut ViewContext<Self>) {
+        if let Some(index) = &self.project_index {
+            index.update(cx, |project_index, cx| {
+                project_index.debug(cx).detach_and_log_err(cx)
+            });
+        }
+    }
+
     async fn request_completion(
         this: WeakView<Self>,
         mode: SubmitMode,
@@ -686,6 +705,7 @@ impl Render for AssistantChat {
             .key_context("AssistantChat")
             .on_action(cx.listener(Self::submit))
             .on_action(cx.listener(Self::cancel))
+            .on_action(cx.listener(Self::debug_project_index))
             .text_color(Color::Default.color(cx))
             .child(list(self.list_state.clone()).flex_1())
             .child(Composer::new(

crates/semantic_index/src/embedding/cloud.rs 🔗

@@ -72,10 +72,11 @@ impl EmbeddingProvider for CloudEmbeddingProvider {
             texts
                 .iter()
                 .map(|to_embed| {
-                    let dimensions = embeddings.remove(&to_embed.digest).with_context(|| {
-                        format!("server did not return an embedding for {:?}", to_embed)
-                    })?;
-                    Ok(Embedding::new(dimensions))
+                    let embedding =
+                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
+                            format!("server did not return an embedding for {:?}", to_embed)
+                        })?;
+                    Ok(Embedding::new(embedding))
                 })
                 .collect()
         }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -21,6 +21,7 @@ use smol::channel;
 use std::{
     cmp::Ordering,
     future::Future,
+    iter,
     num::NonZeroUsize,
     ops::Range,
     path::{Path, PathBuf},
@@ -295,6 +296,28 @@ impl ProjectIndex {
         }
         Ok(result)
     }
+
+    pub fn debug(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let indices = self
+            .worktree_indices
+            .values()
+            .filter_map(|worktree_index| {
+                if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
+                    Some(index.clone())
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        cx.spawn(|_, mut cx| async move {
+            eprintln!("semantic index contents:");
+            for index in indices {
+                index.update(&mut cx, |index, cx| index.debug(cx))?.await?
+            }
+            Ok(())
+        })
+    }
 }
 
 pub struct SearchResult {
@@ -419,7 +442,7 @@ impl WorktreeIndex {
         let worktree_abs_path = worktree.abs_path().clone();
         let scan = self.scan_entries(worktree.clone(), cx);
         let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
-        let embed = self.embed_files(chunk.files, cx);
+        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
         let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
         async move {
             futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -436,7 +459,7 @@ impl WorktreeIndex {
         let worktree_abs_path = worktree.abs_path().clone();
         let scan = self.scan_updated_entries(worktree, updated_entries.clone(), cx);
         let chunk = self.chunk_files(worktree_abs_path, scan.updated_entries, cx);
-        let embed = self.embed_files(chunk.files, cx);
+        let embed = Self::embed_files(self.embedding_provider.clone(), chunk.files, cx);
         let persist = self.persist_embeddings(scan.deleted_entry_ranges, embed.files, cx);
         async move {
             futures::try_join!(scan.task, chunk.task, embed.task, persist)?;
@@ -500,7 +523,7 @@ impl WorktreeIndex {
                 }
 
                 if entry.mtime != saved_mtime {
-                    let handle = entries_being_indexed.insert(&entry);
+                    let handle = entries_being_indexed.insert(entry.id);
                     updated_entries_tx.send((entry.clone(), handle)).await?;
                 }
             }
@@ -539,7 +562,7 @@ impl WorktreeIndex {
                     | project::PathChange::AddedOrUpdated => {
                         if let Some(entry) = worktree.entry_for_id(*entry_id) {
                             if entry.is_file() {
-                                let handle = entries_being_indexed.insert(&entry);
+                                let handle = entries_being_indexed.insert(entry.id);
                                 updated_entries_tx.send((entry.clone(), handle)).await?;
                             }
                         }
@@ -601,7 +624,8 @@ impl WorktreeIndex {
                                 let chunked_file = ChunkedFile {
                                     chunks: chunk_text(&text, grammar),
                                     handle,
-                                    entry,
+                                    path: entry.path,
+                                    mtime: entry.mtime,
                                     text,
                                 };
 
@@ -623,11 +647,11 @@ impl WorktreeIndex {
     }
 
     fn embed_files(
-        &self,
+        embedding_provider: Arc<dyn EmbeddingProvider>,
         chunked_files: channel::Receiver<ChunkedFile>,
         cx: &AppContext,
     ) -> EmbedFiles {
-        let embedding_provider = self.embedding_provider.clone();
+        let embedding_provider = embedding_provider.clone();
         let (embedded_files_tx, embedded_files_rx) = channel::bounded(512);
         let task = cx.background_executor().spawn(async move {
             let mut chunked_file_batches =
@@ -635,9 +659,10 @@ impl WorktreeIndex {
             while let Some(chunked_files) = chunked_file_batches.next().await {
                 // View the batch of files as a vec of chunks
                 // Flatten out to a vec of chunks that we can subdivide into batch sized pieces
-                // Once those are done, reassemble it back into which files they belong to
+                // Once those are done, reassemble them back into the files in which they belong
+                // If any embeddings fail for a file, the entire file is discarded
 
-                let chunks = chunked_files
+                let chunks: Vec<TextToEmbed> = chunked_files
                     .iter()
                     .flat_map(|file| {
                         file.chunks.iter().map(|chunk| TextToEmbed {
@@ -647,36 +672,50 @@ impl WorktreeIndex {
                     })
                     .collect::<Vec<_>>();
 
-                let mut embeddings = Vec::new();
+                let mut embeddings: Vec<Option<Embedding>> = Vec::new();
                 for embedding_batch in chunks.chunks(embedding_provider.batch_size()) {
                     if let Some(batch_embeddings) =
                         embedding_provider.embed(embedding_batch).await.log_err()
                     {
-                        embeddings.extend_from_slice(&batch_embeddings);
+                        if batch_embeddings.len() == embedding_batch.len() {
+                            embeddings.extend(batch_embeddings.into_iter().map(Some));
+                            continue;
+                        }
+                        log::error!(
+                            "embedding provider returned unexpected embedding count {}, expected {}",
+                            batch_embeddings.len(), embedding_batch.len()
+                        );
                     }
+
+                    embeddings.extend(iter::repeat(None).take(embedding_batch.len()));
                 }
 
                 let mut embeddings = embeddings.into_iter();
                 for chunked_file in chunked_files {
-                    let chunk_embeddings = embeddings
-                        .by_ref()
-                        .take(chunked_file.chunks.len())
-                        .collect::<Vec<_>>();
-                    let embedded_chunks = chunked_file
-                        .chunks
-                        .into_iter()
-                        .zip(chunk_embeddings)
-                        .map(|(chunk, embedding)| EmbeddedChunk { chunk, embedding })
-                        .collect();
-                    let embedded_file = EmbeddedFile {
-                        path: chunked_file.entry.path.clone(),
-                        mtime: chunked_file.entry.mtime,
-                        chunks: embedded_chunks,
+                    let mut embedded_file = EmbeddedFile {
+                        path: chunked_file.path,
+                        mtime: chunked_file.mtime,
+                        chunks: Vec::new(),
                     };
 
-                    embedded_files_tx
-                        .send((embedded_file, chunked_file.handle))
-                        .await?;
+                    let mut embedded_all_chunks = true;
+                    for (chunk, embedding) in
+                        chunked_file.chunks.into_iter().zip(embeddings.by_ref())
+                    {
+                        if let Some(embedding) = embedding {
+                            embedded_file
+                                .chunks
+                                .push(EmbeddedChunk { chunk, embedding });
+                        } else {
+                            embedded_all_chunks = false;
+                        }
+                    }
+
+                    if embedded_all_chunks {
+                        embedded_files_tx
+                            .send((embedded_file, chunked_file.handle))
+                            .await?;
+                    }
                 }
             }
             Ok(())
@@ -826,6 +865,21 @@ impl WorktreeIndex {
         })
     }
 
+    fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let connection = self.db_connection.clone();
+        let db = self.db;
+        cx.background_executor().spawn(async move {
+            let tx = connection
+                .read_txn()
+                .context("failed to create read transaction")?;
+            for record in db.iter(&tx)? {
+                let (key, _) = record?;
+                eprintln!("{}", path_for_db_key(key));
+            }
+            Ok(())
+        })
+    }
+
     #[cfg(test)]
     fn path_count(&self) -> Result<u64> {
         let txn = self
@@ -848,7 +902,8 @@ struct ChunkFiles {
 }
 
 struct ChunkedFile {
-    pub entry: Entry,
+    pub path: Arc<Path>,
+    pub mtime: Option<SystemTime>,
     pub handle: IndexingEntryHandle,
     pub text: String,
     pub chunks: Vec<Chunk>,
@@ -872,11 +927,14 @@ struct EmbeddedChunk {
     embedding: Embedding,
 }
 
+/// The set of entries that are currently being indexed.
 struct IndexingEntrySet {
     entry_ids: Mutex<HashSet<ProjectEntryId>>,
     tx: channel::Sender<()>,
 }
 
+/// When dropped, removes the entry from the set of entries that are being indexed.
+#[derive(Clone)]
 struct IndexingEntryHandle {
     entry_id: ProjectEntryId,
     set: Weak<IndexingEntrySet>,
@@ -890,11 +948,11 @@ impl IndexingEntrySet {
         }
     }
 
-    fn insert(self: &Arc<Self>, entry: &project::Entry) -> IndexingEntryHandle {
-        self.entry_ids.lock().insert(entry.id);
+    fn insert(self: &Arc<Self>, entry_id: ProjectEntryId) -> IndexingEntryHandle {
+        self.entry_ids.lock().insert(entry_id);
         self.tx.send_blocking(()).ok();
         IndexingEntryHandle {
-            entry_id: entry.id,
+            entry_id,
             set: Arc::downgrade(self),
         }
     }
@@ -917,6 +975,10 @@ fn db_key_for_path(path: &Arc<Path>) -> String {
     path.to_string_lossy().replace('/', "\0")
 }
 
+fn path_for_db_key(key: &str) -> String {
+    key.replace('\0', "/")
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -939,7 +1001,22 @@ mod tests {
         });
     }
 
-    pub struct TestEmbeddingProvider;
+    pub struct TestEmbeddingProvider {
+        batch_size: usize,
+        compute_embedding: Box<dyn Fn(&str) -> Result<Embedding> + Send + Sync>,
+    }
+
+    impl TestEmbeddingProvider {
+        pub fn new(
+            batch_size: usize,
+            compute_embedding: impl 'static + Fn(&str) -> Result<Embedding> + Send + Sync,
+        ) -> Self {
+            return Self {
+                batch_size,
+                compute_embedding: Box::new(compute_embedding),
+            };
+        }
+    }
 
     impl EmbeddingProvider for TestEmbeddingProvider {
         fn embed<'a>(
@@ -948,29 +1025,13 @@ mod tests {
         ) -> BoxFuture<'a, Result<Vec<Embedding>>> {
             let embeddings = texts
                 .iter()
-                .map(|text| {
-                    let mut embedding = vec![0f32; 2];
-                    // if the text contains garbage, give it a 1 in the first dimension
-                    if text.text.contains("garbage in") {
-                        embedding[0] = 0.9;
-                    } else {
-                        embedding[0] = -0.9;
-                    }
-
-                    if text.text.contains("garbage out") {
-                        embedding[1] = 0.9;
-                    } else {
-                        embedding[1] = -0.9;
-                    }
-
-                    Embedding::new(embedding)
-                })
+                .map(|to_embed| (self.compute_embedding)(to_embed.text))
                 .collect();
-            future::ready(Ok(embeddings)).boxed()
+            future::ready(embeddings).boxed()
         }
 
         fn batch_size(&self) -> usize {
-            16
+            self.batch_size
         }
     }
 
@@ -984,7 +1045,23 @@ mod tests {
 
         let mut semantic_index = SemanticIndex::new(
             temp_dir.path().into(),
-            Arc::new(TestEmbeddingProvider),
+            Arc::new(TestEmbeddingProvider::new(16, |text| {
+                let mut embedding = vec![0f32; 2];
+                // if the text contains garbage, give it a 1 in the first dimension
+                if text.contains("garbage in") {
+                    embedding[0] = 0.9;
+                } else {
+                    embedding[0] = -0.9;
+                }
+
+                if text.contains("garbage out") {
+                    embedding[1] = 0.9;
+                } else {
+                    embedding[1] = -0.9;
+                }
+
+                Ok(Embedding::new(embedding))
+            })),
             &mut cx.to_async(),
         )
         .await
@@ -1046,4 +1123,82 @@ mod tests {
 
         assert!(content.contains("garbage in, garbage out"));
     }
+
+    #[gpui::test]
+    async fn test_embed_files(cx: &mut TestAppContext) {
+        cx.executor().allow_parking();
+
+        let provider = Arc::new(TestEmbeddingProvider::new(3, |text| {
+            if text.contains('g') {
+                Err(anyhow!("cannot embed text containing a 'g' character"))
+            } else {
+                Ok(Embedding::new(
+                    ('a'..'z')
+                        .map(|char| text.chars().filter(|c| *c == char).count() as f32)
+                        .collect(),
+                ))
+            }
+        }));
+
+        let (indexing_progress_tx, _) = channel::unbounded();
+        let indexing_entries = Arc::new(IndexingEntrySet::new(indexing_progress_tx));
+
+        let (chunked_files_tx, chunked_files_rx) = channel::unbounded::<ChunkedFile>();
+        chunked_files_tx
+            .send_blocking(ChunkedFile {
+                path: Path::new("test1.md").into(),
+                mtime: None,
+                handle: indexing_entries.insert(ProjectEntryId::from_proto(0)),
+                text: "abcdefghijklmnop".to_string(),
+                chunks: [0..4, 4..8, 8..12, 12..16]
+                    .into_iter()
+                    .map(|range| Chunk {
+                        range,
+                        digest: Default::default(),
+                    })
+                    .collect(),
+            })
+            .unwrap();
+        chunked_files_tx
+            .send_blocking(ChunkedFile {
+                path: Path::new("test2.md").into(),
+                mtime: None,
+                handle: indexing_entries.insert(ProjectEntryId::from_proto(1)),
+                text: "qrstuvwxyz".to_string(),
+                chunks: [0..4, 4..8, 8..10]
+                    .into_iter()
+                    .map(|range| Chunk {
+                        range,
+                        digest: Default::default(),
+                    })
+                    .collect(),
+            })
+            .unwrap();
+        chunked_files_tx.close();
+
+        let embed_files_task =
+            cx.update(|cx| WorktreeIndex::embed_files(provider.clone(), chunked_files_rx, cx));
+        embed_files_task.task.await.unwrap();
+
+        let mut embedded_files_rx = embed_files_task.files;
+        let mut embedded_files = Vec::new();
+        while let Some((embedded_file, _)) = embedded_files_rx.next().await {
+            embedded_files.push(embedded_file);
+        }
+
+        assert_eq!(embedded_files.len(), 1);
+        assert_eq!(embedded_files[0].path.as_ref(), Path::new("test2.md"));
+        assert_eq!(
+            embedded_files[0]
+                .chunks
+                .iter()
+                .map(|embedded_chunk| { embedded_chunk.embedding.clone() })
+                .collect::<Vec<Embedding>>(),
+            vec![
+                (provider.compute_embedding)("qrst").unwrap(),
+                (provider.compute_embedding)("uvwx").unwrap(),
+                (provider.compute_embedding)("yz").unwrap(),
+            ],
+        );
+    }
 }