@@ -15,7 +15,7 @@ use gpui::{
use heed::types::{SerdeBincode, Str};
use language::LanguageRegistry;
use parking_lot::Mutex;
-use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree};
+use project::{Entry, Project, ProjectEntryId, UpdatedEntriesSet, Worktree, WorktreeId};
use serde::{Deserialize, Serialize};
use smol::channel;
use std::{
@@ -156,6 +156,10 @@ impl ProjectIndex {
self.last_status
}
+ pub fn project(&self) -> WeakModel<Project> {
+ self.project.clone()
+ }
+
fn handle_project_event(
&mut self,
_: Model<Project>,
@@ -259,30 +263,126 @@ impl ProjectIndex {
}
}
- pub fn search(&self, query: &str, limit: usize, cx: &AppContext) -> Task<Vec<SearchResult>> {
- let mut worktree_searches = Vec::new();
+ pub fn search(
+ &self,
+ query: String,
+ limit: usize,
+ cx: &AppContext,
+ ) -> Task<Result<Vec<SearchResult>>> {
+ let (chunks_tx, chunks_rx) = channel::bounded(1024);
+ let mut worktree_scan_tasks = Vec::new();
for worktree_index in self.worktree_indices.values() {
if let WorktreeIndexHandle::Loaded { index, .. } = worktree_index {
- worktree_searches
- .push(index.read_with(cx, |index, cx| index.search(query, limit, cx)));
+ let chunks_tx = chunks_tx.clone();
+ index.read_with(cx, |index, cx| {
+ let worktree_id = index.worktree.read(cx).id();
+ let db_connection = index.db_connection.clone();
+ let db = index.db;
+ worktree_scan_tasks.push(cx.background_executor().spawn({
+ async move {
+ let txn = db_connection
+ .read_txn()
+ .context("failed to create read transaction")?;
+ let db_entries = db.iter(&txn).context("failed to iterate database")?;
+ for db_entry in db_entries {
+ let (_key, db_embedded_file) = db_entry?;
+ for chunk in db_embedded_file.chunks {
+ chunks_tx
+ .send((worktree_id, db_embedded_file.path.clone(), chunk))
+ .await?;
+ }
+ }
+ anyhow::Ok(())
+ }
+ }));
+ })
}
}
+ drop(chunks_tx);
- cx.spawn(|_| async move {
- let mut results = Vec::new();
- let worktree_searches = futures::future::join_all(worktree_searches).await;
+ let project = self.project.clone();
+ let embedding_provider = self.embedding_provider.clone();
+ cx.spawn(|cx| async move {
+ #[cfg(debug_assertions)]
+ let embedding_query_start = std::time::Instant::now();
+ log::info!("Searching for {query}");
- for worktree_search_results in worktree_searches {
- if let Some(worktree_search_results) = worktree_search_results.log_err() {
- results.extend(worktree_search_results);
- }
+ let query_embeddings = embedding_provider
+ .embed(&[TextToEmbed::new(&query)])
+ .await?;
+ let query_embedding = query_embeddings
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("no embedding for query"))?;
+
+ let mut results_by_worker = Vec::new();
+ for _ in 0..cx.background_executor().num_cpus() {
+ results_by_worker.push(Vec::<WorktreeSearchResult>::new());
}
- results
- .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
- results.truncate(limit);
+ #[cfg(debug_assertions)]
+ let search_start = std::time::Instant::now();
- results
+ cx.background_executor()
+ .scoped(|cx| {
+ for results in results_by_worker.iter_mut() {
+ cx.spawn(async {
+ while let Ok((worktree_id, path, chunk)) = chunks_rx.recv().await {
+ let score = chunk.embedding.similarity(&query_embedding);
+ let ix = match results.binary_search_by(|probe| {
+ score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
+ }) {
+ Ok(ix) | Err(ix) => ix,
+ };
+ results.insert(
+ ix,
+ WorktreeSearchResult {
+ worktree_id,
+ path: path.clone(),
+ range: chunk.chunk.range.clone(),
+ score,
+ },
+ );
+ results.truncate(limit);
+ }
+ });
+ }
+ })
+ .await;
+
+ futures::future::try_join_all(worktree_scan_tasks).await?;
+
+ project.read_with(&cx, |project, cx| {
+ let mut search_results = Vec::with_capacity(results_by_worker.len() * limit);
+ for worker_results in results_by_worker {
+ search_results.extend(worker_results.into_iter().filter_map(|result| {
+ Some(SearchResult {
+ worktree: project.worktree_for_id(result.worktree_id, cx)?,
+ path: result.path,
+ range: result.range,
+ score: result.score,
+ })
+ }));
+ }
+ search_results.sort_unstable_by(|a, b| {
+ b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal)
+ });
+ search_results.truncate(limit);
+
+ #[cfg(debug_assertions)]
+ {
+ let search_elapsed = search_start.elapsed();
+ log::debug!(
+ "searched {} entries in {:?}",
+ search_results.len(),
+ search_elapsed
+ );
+ let embedding_query_elapsed = embedding_query_start.elapsed();
+ log::debug!("embedding query took {:?}", embedding_query_elapsed);
+ }
+
+ search_results
+ })
})
}
@@ -327,6 +427,13 @@ pub struct SearchResult {
pub score: f32,
}
+pub struct WorktreeSearchResult {
+ pub worktree_id: WorktreeId,
+ pub path: Arc<Path>,
+ pub range: Range<usize>,
+ pub score: f32,
+}
+
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Status {
Idle,
@@ -764,107 +871,6 @@ impl WorktreeIndex {
})
}
- fn search(
- &self,
- query: &str,
- limit: usize,
- cx: &AppContext,
- ) -> Task<Result<Vec<SearchResult>>> {
- let (chunks_tx, chunks_rx) = channel::bounded(1024);
-
- let db_connection = self.db_connection.clone();
- let db = self.db;
- let scan_chunks = cx.background_executor().spawn({
- async move {
- let txn = db_connection
- .read_txn()
- .context("failed to create read transaction")?;
- let db_entries = db.iter(&txn).context("failed to iterate database")?;
- for db_entry in db_entries {
- let (_key, db_embedded_file) = db_entry?;
- for chunk in db_embedded_file.chunks {
- chunks_tx
- .send((db_embedded_file.path.clone(), chunk))
- .await?;
- }
- }
- anyhow::Ok(())
- }
- });
-
- let query = query.to_string();
- let embedding_provider = self.embedding_provider.clone();
- let worktree = self.worktree.clone();
- cx.spawn(|cx| async move {
- #[cfg(debug_assertions)]
- let embedding_query_start = std::time::Instant::now();
- log::info!("Searching for {query}");
-
- let mut query_embeddings = embedding_provider
- .embed(&[TextToEmbed::new(&query)])
- .await?;
- let query_embedding = query_embeddings
- .pop()
- .ok_or_else(|| anyhow!("no embedding for query"))?;
- let mut workers = Vec::new();
- for _ in 0..cx.background_executor().num_cpus() {
- workers.push(Vec::<SearchResult>::new());
- }
-
- #[cfg(debug_assertions)]
- let search_start = std::time::Instant::now();
-
- cx.background_executor()
- .scoped(|cx| {
- for worker_results in workers.iter_mut() {
- cx.spawn(async {
- while let Ok((path, embedded_chunk)) = chunks_rx.recv().await {
- let score = embedded_chunk.embedding.similarity(&query_embedding);
- let ix = match worker_results.binary_search_by(|probe| {
- score.partial_cmp(&probe.score).unwrap_or(Ordering::Equal)
- }) {
- Ok(ix) | Err(ix) => ix,
- };
- worker_results.insert(
- ix,
- SearchResult {
- worktree: worktree.clone(),
- path: path.clone(),
- range: embedded_chunk.chunk.range.clone(),
- score,
- },
- );
- worker_results.truncate(limit);
- }
- });
- }
- })
- .await;
- scan_chunks.await?;
-
- let mut search_results = Vec::with_capacity(workers.len() * limit);
- for worker_results in workers {
- search_results.extend(worker_results);
- }
- search_results
- .sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
- search_results.truncate(limit);
- #[cfg(debug_assertions)]
- {
- let search_elapsed = search_start.elapsed();
- log::debug!(
- "searched {} entries in {:?}",
- search_results.len(),
- search_elapsed
- );
- let embedding_query_elapsed = embedding_query_start.elapsed();
- log::debug!("embedding query took {:?}", embedding_query_elapsed);
- }
-
- Ok(search_results)
- })
- }
-
fn debug(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let connection = self.db_connection.clone();
let db = self.db;
@@ -1093,9 +1099,10 @@ mod tests {
.update(|cx| {
let project_index = project_index.read(cx);
let query = "garbage in, garbage out";
- project_index.search(query, 4, cx)
+ project_index.search(query.into(), 4, cx)
})
- .await;
+ .await
+ .unwrap();
assert!(results.len() > 1, "should have found some results");
@@ -1112,9 +1119,10 @@ mod tests {
let content = cx
.update(|cx| {
let worktree = search_result.worktree.read(cx);
- let entry_abs_path = worktree.abs_path().join(search_result.path.clone());
+ let entry_abs_path = worktree.abs_path().join(&search_result.path);
let fs = project.read(cx).fs().clone();
- cx.spawn(|_| async move { fs.load(&entry_abs_path).await.unwrap() })
+ cx.background_executor()
+ .spawn(async move { fs.load(&entry_abs_path).await.unwrap() })
})
.await;