Return an error from project index tool when embedding query fails (#11264)

Max Brunsfeld created

Previously, a failure to embed the search query (e.g., due to a rate-limit
error) would appear the same as if there were no results.

* Avoid repeatedly embedding the search query for each worktree
* Unify tasks for searching all worktrees

Release Notes:

- N/A

Change summary

crates/assistant2/src/tools/project_index.rs |   5 
crates/semantic_index/examples/index.rs      |   5 
crates/semantic_index/src/chunking.rs        |   9 
crates/semantic_index/src/semantic_index.rs  | 250 +++++++++++----------
4 files changed, 137 insertions(+), 132 deletions(-)

Detailed changes

crates/assistant2/src/tools/project_index.rs 🔗

@@ -140,10 +140,9 @@ impl LanguageModelTool for ProjectIndexTool {
 
     fn execute(&self, query: &Self::Input, cx: &mut WindowContext) -> Task<Result<Self::Output>> {
         let project_index = self.project_index.read(cx);
-
         let status = project_index.status();
         let results = project_index.search(
-            query.query.as_str(),
+            query.query.clone(),
             query.limit.unwrap_or(DEFAULT_SEARCH_LIMIT),
             cx,
         );
@@ -151,7 +150,7 @@ impl LanguageModelTool for ProjectIndexTool {
         let fs = self.fs.clone();
 
         cx.spawn(|cx| async move {
-            let results = results.await;
+            let results = results.await?;
 
             let excerpts = results.into_iter().map(|result| {
                 let abs_path = result

crates/semantic_index/examples/index.rs 🔗

@@ -92,10 +92,11 @@ fn main() {
                 .update(|cx| {
                     let project_index = project_index.read(cx);
                     let query = "converting an anchor to a point";
-                    project_index.search(query, 4, cx)
+                    project_index.search(query.into(), 4, cx)
                 })
                 .unwrap()
-                .await;
+                .await
+                .unwrap();
 
             for search_result in results {
                 let path = search_result.path.clone();

crates/semantic_index/src/chunking.rs 🔗

@@ -98,12 +98,9 @@ fn chunk_lines(text: &str) -> Vec<Chunk> {
 
     chunk_ranges
         .into_iter()
-        .map(|range| {
-            let mut hasher = Sha256::new();
-            hasher.update(&text[range.clone()]);
-            let mut digest = [0u8; 32];
-            digest.copy_from_slice(hasher.finalize().as_slice());
-            Chunk { range, digest }
+        .map(|range| Chunk {
+            digest: Sha256::digest(&text[range.clone()]).into(),
+            range,
         })
         .collect()
 }

crates/semantic_index/src/semantic_index.rs 🔗

@@ -15,7 +15,7 @@ use gpui::{
 use heed::types::{SerdeBincode, Str};
 use language::LanguageRegistry;
 use parking_lot::Mutex;
-use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
 use serde::{Deserialize, Serialize};
 use smol::channel;
 use std::{
@@ -156,6 +156,10 @@ impl ProjectIndex {
         self.last_status
     }
 
+    pub fn project(&self) -> WeakModel<Project> {
+        self.project.clone()
+    }
+
     fn handle_project_event(
         &mut self,
         _: Model<Project>,
@@ -259,30 +263,126 @@ impl ProjectIndex {
         }
     }
 
-    pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
-        let mut worktree_searches = Vec::new();
+    pub fn search(
+        &self,
+        query: String,
+        limit: usize,
+        cx: &AppContext,
+    ) -> Task<Result<Vec<SearchResult>>> {
+        let (chunks_tx, chunks_rx) = channel::bounded(1024);
+        let mut worktree_scan_tasks = Vec::new();
         for worktree_index in self.worktree_indices.values() {
             if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
-                worktree_searches
-                    .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+                let chunks_tx = chunks_tx.clone();
+                index.read_with(cx, |index, cx| {
+                    let worktree_id = index.worktree.read(cx).id();
+                    let db_connection = index.db_connection.clone();
+                    let db = index.db;
+                    worktree_scan_tasks.push(cx.background_executor().spawn({
+                        async move {
+                            let txn = db_connection
+                                .read_txn()
+                                .context("failed to create read transaction")?;
+                            let db_entries = db.iter(&txn).context("failed to iterate database")?;
+                            for db_entry in db_entries {
+                                let (_key, db_embedded_file) = db_entry?;
+                                for chunk in db_embedded_file.chunks {
+                                    chunks_tx
+                                        .send((worktree_id, db_embedded_file.path.clone(), chunk))
+                                        .await?;
+                                }
+                            }
+                            anyhow::Ok(())
+                        }
+                    }));
+                })
             }
         }
+        drop(chunks_tx);
 
-        cx.spawn(|_| async move {
-            let mut results = Vec::new();
-            let worktree_searches = futures::future::join_all(worktree_searches).await;
+        let project = self.project.clone();
+        let embedding_provider = self.embedding_provider.clone();
+        cx.spawn(|cx| async move {
+            #[cfg(debug_assertions)]
+            let embedding_query_start = std::time::Instant::now();
+            log::info!("Searching for {query}");
 
-            for worktree_search_results in worktree_searches {
-                if let Some(worktree_search_results) = worktree_search_results.log_err() {
-                    results.extend(worktree_search_results);
-                }
+            let query_embeddings = embedding_provider
+                .embed(&[TextToEmbed::new(&query)])
+                .await?;
+            let query_embedding = query_embeddings
+                .into_iter()
+                .next()
+                .ok_or_else(|| anyhow!("no embedding for query"))?;
+
+            let mut results_by_worker = Vec::new();
+            for _ in 0..cx.background_executor().num_cpus() {
+                results_by_worker.push(Vec::<WorktreeSearchResult>::new());
             }
 
-            results
-                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
-            results.truncate(limit);
+            #[cfg(debug_assertions)]
+            let search_start = std::time::Instant::now();
 
-            results
+            cx.background_executor()
+                .scoped(|cx| {
+                    for results in results_by_worker.iter_mut() {
+                        cx.spawn(async {
+                            while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
+                                let score = chunk.embedding.similarity(&query_embedding);
+                                let ix = match results.binary_search_by(|probe| {
+                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+                                }) {
+                                    Ok(ix) | Err(ix) => ix,
+                                };
+                                results.insert(
+                                    ix,
+                                    WorktreeSearchResult {
+                                        worktree_id,
+                                        path: path.clone(),
+                                        range: chunk.chunk.range.clone(),
+                                        score,
+                                    },
+                                );
+                                results.truncate(limit);
+                            }
+                        });
+                    }
+                })
+                .await;
+
+            futures::future::try_join_all(worktree_scan_tasks).await?;
+
+            project.read_with(&cx, |project, cx| {
+                let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
+                for worker_results in results_by_worker {
+                    search_results.extend(worker_results.into_iter().filter_map(|result| {
+                        Some(SearchResult {
+                            worktree: project.worktree_for_id(result.worktree_id, cx)?,
+                            path: result.path,
+                            range: result.range,
+                            score: result.score,
+                        })
+                    }));
+                }
+                search_results.sort_unstable_by(|a, b| {
+                    b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
+                });
+                search_results.truncate(limit);
+
+                #[cfg(debug_assertions)]
+                {
+                    let search_elapsed = search_start.elapsed();
+                    log::debug!(
+                        "searched {} entries in {:?}",
+                        search_results.len(),
+                        search_elapsed
+                    );
+                    let embedding_query_elapsed = embedding_query_start.elapsed();
+                    log::debug!("embedding query took {:?}", embedding_query_elapsed);
+                }
+
+                search_results
+            })
         })
     }
 
@@ -327,6 +427,13 @@ pub struct SearchResult {
     pub score: f32,
 }
 
+pub struct WorktreeSearchResult {
+    pub worktree_id: WorktreeId,
+    pub path: Arc<Path>,
+    pub range: Range<usize>,
+    pub score: f32,
+}
+
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum Status {
     Idle,
@@ -764,107 +871,6 @@ impl WorktreeIndex {
         })
     }
 
-    fn search(
-        &self,
-        query: &str,
-        limit: usize,
-        cx: &AppContext,
-    ) -> Task<Result<Vec<SearchResult>>> {
-        let (chunks_tx, chunks_rx) = channel::bounded(1024);
-
-        let db_connection = self.db_connection.clone();
-        let db = self.db;
-        let scan_chunks = cx.background_executor().spawn({
-            async move {
-                let txn = db_connection
-                    .read_txn()
-                    .context("failed to create read transaction")?;
-                let db_entries = db.iter(&txn).context("failed to iterate database")?;
-                for db_entry in db_entries {
-                    let (_key, db_embedded_file) = db_entry?;
-                    for chunk in db_embedded_file.chunks {
-                        chunks_tx
-                            .send((db_embedded_file.path.clone(), chunk))
-                            .await?;
-                    }
-                }
-                anyhow::Ok(())
-            }
-        });
-
-        let query = query.to_string();
-        let embedding_provider = self.embedding_provider.clone();
-        let worktree = self.worktree.clone();
-        cx.spawn(|cx| async move {
-            #[cfg(debug_assertions)]
-            let embedding_query_start = std::time::Instant::now();
-            log::info!("Searching for {query}");
-
-            let mut query_embeddings = embedding_provider
-                .embed(&[TextToEmbed::new(&query)])
-                .await?;
-            let query_embedding = query_embeddings
-                .pop()
-                .ok_or_else(|| anyhow!("no embedding for query"))?;
-            let mut workers = Vec::new();
-            for _ in 0..cx.background_executor().num_cpus() {
-                workers.push(Vec::<SearchResult>::new());
-            }
-
-            #[cfg(debug_assertions)]
-            let search_start = std::time::Instant::now();
-
-            cx.background_executor()
-                .scoped(|cx| {
-                    for worker_results in workers.iter_mut() {
-                        cx.spawn(async {
-                            while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
-                                let score = embedded_chunk.embedding.similarity(&query_embedding);
-                                let ix = match worker_results.binary_search_by(|probe| {
-                                    score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
-                                }) {
-                                    Ok(ix) | Err(ix) => ix,
-                                };
-                                worker_results.insert(
-                                    ix,
-                                    SearchResult {
-                                        worktree: worktree.clone(),
-                                        path: path.clone(),
-                                        range: embedded_chunk.chunk.range.clone(),
-                                        score,
-                                    },
-                                );
-                                worker_results.truncate(limit);
-                            }
-                        });
-                    }
-                })
-                .await;
-            scan_chunks.await?;
-
-            let mut search_results = Vec::with_capacity(workers.len() * limit);
-            for worker_results in workers {
-                search_results.extend(worker_results);
-            }
-            search_results
-                .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
-            search_results.truncate(limit);
-            #[cfg(debug_assertions)]
-            {
-                let search_elapsed = search_start.elapsed();
-                log::debug!(
-                    "searched {} entries in {:?}",
-                    search_results.len(),
-                    search_elapsed
-                );
-                let embedding_query_elapsed = embedding_query_start.elapsed();
-                log::debug!("embedding query took {:?}", embedding_query_elapsed);
-            }
-
-            Ok(search_results)
-        })
-    }
-
     fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let connection = self.db_connection.clone();
         let db = self.db;
@@ -1093,9 +1099,10 @@ mod tests {
             .update(|cx| {
                 let project_index = project_index.read(cx);
                 let query = "garbage in, garbage out";
-                project_index.search(query, 4, cx)
+                project_index.search(query.into(), 4, cx)
             })
-            .await;
+            .await
+            .unwrap();
 
         assert!(results.len() > 1, "should have found some results");
 
@@ -1112,9 +1119,10 @@ mod tests {
         let content = cx
             .update(|cx| {
                 let worktree = search_result.worktree.read(cx);
-                let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+                let entry_abs_path = worktree.abs_path().join(&search_result.path);
                 let fs = project.read(cx).fs().clone();
-                cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+                cx.background_executor()
+                    .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
             })
             .await;