project search: Make cancellation smoother (#45406)

Piotr Osiewicz and Max Brunsfeld created

- **search: Make search cancellation more responsive (again)**
- **Fix project benchmarks build**
- **Less scoping and lifetimes for workers**

Related to #45300

Release Notes:

- Project search will consume less resources immediately after
cancellation.

---------

Co-authored-by: Max Brunsfeld <max@zed.dev>

Change summary

crates/agent/src/tools/grep_tool.rs                           |   9 
crates/collab/src/tests/integration_tests.rs                  |   2 
crates/collab/src/tests/random_project_collaboration_tests.rs |   2 
crates/project/src/project.rs                                 |  11 
crates/project/src/project_search.rs                          | 118 ++-
crates/project/src/project_tests.rs                           |   2 
crates/project/src/search.rs                                  |  48 +
crates/project/src/worktree_store.rs                          | 149 ----
crates/project_benchmarks/src/main.rs                         |   2 
crates/remote_server/src/headless_project.rs                  |   2 
crates/remote_server/src/remote_editing_tests.rs              |   4 
crates/search/src/project_search.rs                           |   6 
12 files changed, 133 insertions(+), 222 deletions(-)

Detailed changes

crates/agent/src/tools/grep_tool.rs 🔗

@@ -5,7 +5,7 @@ use futures::StreamExt;
 use gpui::{App, Entity, SharedString, Task};
 use language::{OffsetRangeExt, ParseStatus, Point};
 use project::{
-    Project, WorktreeSettings,
+    Project, SearchResults, WorktreeSettings,
     search::{SearchQuery, SearchResult},
 };
 use schemars::JsonSchema;
@@ -176,14 +176,17 @@ impl AgentTool for GrepTool {
 
         let project = self.project.downgrade();
         cx.spawn(async move |cx|  {
-            futures::pin_mut!(results);
+            // Keep the search alive for the duration of result iteration. Dropping this task is the
+            // cancellation mechanism; we intentionally do not detach it.
+            let SearchResults {rx, _task_handle}  = results;
+            futures::pin_mut!(rx);
 
             let mut output = String::new();
             let mut skips_remaining = input.offset;
             let mut matches_found = 0;
             let mut has_more_matches = false;
 
-            'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
+            'outer: while let Some(SearchResult::Buffer { buffer, ranges }) = rx.next().await {
                 if ranges.is_empty() {
                     continue;
                 }

crates/collab/src/tests/integration_tests.rs 🔗

@@ -5195,7 +5195,7 @@ async fn test_project_search(
             cx,
         )
     });
-    while let Ok(result) = search_rx.recv().await {
+    while let Ok(result) = search_rx.rx.recv().await {
         match result {
             SearchResult::Buffer { buffer, ranges } => {
                 results.entry(buffer).or_insert(ranges);

crates/collab/src/tests/random_project_collaboration_tests.rs 🔗

@@ -905,7 +905,7 @@ impl RandomizedTest for ProjectCollaborationTest {
                 drop(project);
                 let search = cx.executor().spawn(async move {
                     let mut results = HashMap::default();
-                    while let Ok(result) = search.recv().await {
+                    while let Ok(result) = search.rx.recv().await {
                         if let SearchResult::Buffer { buffer, ranges } = result {
                             results.entry(buffer).or_insert(ranges);
                         }

crates/project/src/project.rs 🔗

@@ -48,7 +48,7 @@ pub use git_store::{
     git_traversal::{ChildEntriesGitIter, GitEntry, GitEntryRef, GitTraversal},
 };
 pub use manifest_tree::ManifestTree;
-pub use project_search::Search;
+pub use project_search::{Search, SearchResults};
 
 use anyhow::{Context as _, Result, anyhow};
 use buffer_store::{BufferStore, BufferStoreEvent};
@@ -110,7 +110,6 @@ use rpc::{
 use search::{SearchInputKind, SearchQuery, SearchResult};
 use search_history::SearchHistory;
 use settings::{InvalidSettingsError, RegisterSetting, Settings, SettingsLocation, SettingsStore};
-use smol::channel::Receiver;
 use snippet::Snippet;
 pub use snippet_provider;
 use snippet_provider::SnippetProvider;
@@ -4143,7 +4142,11 @@ impl Project {
         searcher.into_handle(query, cx)
     }
 
-    pub fn search(&mut self, query: SearchQuery, cx: &mut Context<Self>) -> Receiver<SearchResult> {
+    pub fn search(
+        &mut self,
+        query: SearchQuery,
+        cx: &mut Context<Self>,
+    ) -> SearchResults<SearchResult> {
         self.search_impl(query, cx).results(cx)
     }
 
@@ -5028,7 +5031,7 @@ impl Project {
             buffer_ids: Vec::new(),
         };
 
-        while let Ok(buffer) = results.recv().await {
+        while let Ok(buffer) = results.rx.recv().await {
             this.update(&mut cx, |this, cx| {
                 let buffer_id = this.create_buffer_for_peer(&buffer, peer_id, cx);
                 response.buffer_ids.push(buffer_id.to_proto());

crates/project/src/project_search.rs 🔗

@@ -67,14 +67,22 @@ pub struct SearchResultsHandle {
     trigger_search: Box<dyn FnOnce(&mut App) -> Task<()> + Send + Sync>,
 }
 
+pub struct SearchResults<T> {
+    pub _task_handle: Task<()>,
+    pub rx: Receiver<T>,
+}
 impl SearchResultsHandle {
-    pub fn results(self, cx: &mut App) -> Receiver<SearchResult> {
-        (self.trigger_search)(cx).detach();
-        self.results
+    pub fn results(self, cx: &mut App) -> SearchResults<SearchResult> {
+        SearchResults {
+            _task_handle: (self.trigger_search)(cx),
+            rx: self.results,
+        }
     }
-    pub fn matching_buffers(self, cx: &mut App) -> Receiver<Entity<Buffer>> {
-        (self.trigger_search)(cx).detach();
-        self.matching_buffers
+    pub fn matching_buffers(self, cx: &mut App) -> SearchResults<Entity<Buffer>> {
+        SearchResults {
+            _task_handle: (self.trigger_search)(cx),
+            rx: self.matching_buffers,
+        }
     }
 }
 
@@ -165,6 +173,7 @@ impl Search {
                 unnamed_buffers.push(handle)
             };
         }
+        let open_buffers = Arc::new(open_buffers);
         let executor = cx.background_executor().clone();
         let (tx, rx) = unbounded();
         let (grab_buffer_snapshot_tx, grab_buffer_snapshot_rx) = unbounded();
@@ -215,7 +224,7 @@ impl Search {
                             ))
                             .boxed_local(),
                             Self::open_buffers(
-                                &self.buffer_store,
+                                self.buffer_store,
                                 get_buffer_for_full_scan_rx,
                                 grab_buffer_snapshot_tx,
                                 cx.clone(),
@@ -248,24 +257,26 @@ impl Search {
                             query: Some(query.to_proto()),
                             limit: self.limit as _,
                         });
+                        let weak_buffer_store = self.buffer_store.downgrade();
+                        let buffer_store = self.buffer_store;
                         let Ok(guard) = cx.update(|cx| {
                             Project::retain_remotely_created_models_impl(
                                 &models,
-                                &self.buffer_store,
+                                &buffer_store,
                                 &self.worktree_store,
                                 cx,
                             )
                         }) else {
                             return;
                         };
-                        let buffer_store = self.buffer_store.downgrade();
+
                         let issue_remote_buffers_request = cx
                             .spawn(async move |cx| {
                                 let _ = maybe!(async move {
                                     let response = request.await?;
                                     for buffer_id in response.buffer_ids {
                                         let buffer_id = BufferId::new(buffer_id)?;
-                                        let buffer = buffer_store
+                                        let buffer = weak_buffer_store
                                             .update(cx, |buffer_store, cx| {
                                                 buffer_store.wait_for_remote_buffer(buffer_id, cx)
                                             })?
@@ -289,22 +300,27 @@ impl Search {
 
                 let should_find_all_matches = !tx.is_closed();
 
-                let worker_pool = executor.scoped(|scope| {
-                    let num_cpus = executor.num_cpus();
+                let _executor = executor.clone();
+                let worker_pool = executor.spawn(async move {
+                    let num_cpus = _executor.num_cpus();
 
                     assert!(num_cpus > 0);
-                    for _ in 0..executor.num_cpus() - 1 {
-                        let worker = Worker {
-                            query: &query,
-                            open_buffers: &open_buffers,
-                            candidates: candidate_searcher.clone(),
-                            find_all_matches_rx: find_all_matches_rx.clone(),
-                        };
-                        scope.spawn(worker.run());
-                    }
+                    _executor
+                        .scoped(|scope| {
+                            for _ in 0..num_cpus - 1 {
+                                let worker = Worker {
+                                    query: query.clone(),
+                                    open_buffers: open_buffers.clone(),
+                                    candidates: candidate_searcher.clone(),
+                                    find_all_matches_rx: find_all_matches_rx.clone(),
+                                };
+                                scope.spawn(worker.run());
+                            }
 
-                    drop(find_all_matches_rx);
-                    drop(candidate_searcher);
+                            drop(find_all_matches_rx);
+                            drop(candidate_searcher);
+                        })
+                        .await;
                 });
 
                 let (sorted_matches_tx, sorted_matches_rx) = unbounded();
@@ -366,6 +382,7 @@ impl Search {
         async move |cx| {
             _ = maybe!(async move {
                 let gitignored_tracker = PathInclusionMatcher::new(query.clone());
+                let include_ignored = query.include_ignored();
                 for worktree in worktrees {
                     let (mut snapshot, worktree_settings) = worktree
                         .read_with(cx, |this, _| {
@@ -398,27 +415,28 @@ impl Search {
                         }
                         snapshot = worktree.read_with(cx, |this, _| this.snapshot())?;
                     }
+                    let tx = tx.clone();
+                    let results = results.clone();
+
                     cx.background_executor()
-                        .scoped(|scope| {
-                            scope.spawn(async {
-                                for entry in snapshot.files(query.include_ignored(), 0) {
-                                    let (should_scan_tx, should_scan_rx) = oneshot::channel();
-
-                                    let Ok(_) = tx
-                                        .send(InputPath {
-                                            entry: entry.clone(),
-                                            snapshot: snapshot.clone(),
-                                            should_scan_tx,
-                                        })
-                                        .await
-                                    else {
-                                        return;
-                                    };
-                                    if results.send(should_scan_rx).await.is_err() {
-                                        return;
-                                    };
-                                }
-                            })
+                        .spawn(async move {
+                            for entry in snapshot.files(include_ignored, 0) {
+                                let (should_scan_tx, should_scan_rx) = oneshot::channel();
+
+                                let Ok(_) = tx
+                                    .send(InputPath {
+                                        entry: entry.clone(),
+                                        snapshot: snapshot.clone(),
+                                        should_scan_tx,
+                                    })
+                                    .await
+                                else {
+                                    return;
+                                };
+                                if results.send(should_scan_rx).await.is_err() {
+                                    return;
+                                };
+                            }
                         })
                         .await;
                 }
@@ -452,7 +470,7 @@ impl Search {
 
     /// Background workers cannot open buffers by themselves, hence main thread will do it on their behalf.
     async fn open_buffers(
-        buffer_store: &Entity<BufferStore>,
+        buffer_store: Entity<BufferStore>,
         rx: Receiver<ProjectPath>,
         find_all_matches_tx: Sender<Entity<Buffer>>,
         mut cx: AsyncApp,
@@ -570,9 +588,9 @@ impl Search {
     }
 }
 
-struct Worker<'search> {
-    query: &'search SearchQuery,
-    open_buffers: &'search HashSet<ProjectEntryId>,
+struct Worker {
+    query: Arc<SearchQuery>,
+    open_buffers: Arc<HashSet<ProjectEntryId>>,
     candidates: FindSearchCandidates,
     /// Ok, we're back in background: run full scan & find all matches in a given buffer snapshot.
     /// Then, when you're done, share them via the channel you were given.
@@ -583,7 +601,7 @@ struct Worker<'search> {
     )>,
 }
 
-impl Worker<'_> {
+impl Worker {
     async fn run(self) {
         let (
             input_paths_rx,
@@ -614,7 +632,7 @@ impl Worker<'_> {
 
         loop {
             let handler = RequestHandler {
-                query: self.query,
+                query: &self.query,
                 open_entries: &self.open_buffers,
                 fs: fs.as_deref(),
                 confirm_contents_will_match_tx: &confirm_contents_will_match_tx,
@@ -701,7 +719,7 @@ impl RequestHandler<'_> {
                 return Ok(());
             }
 
-            if self.query.detect(file).unwrap_or(false) {
+            if self.query.detect(file).await.unwrap_or(false) {
                 // Yes, we should scan the whole file.
                 entry.should_scan_tx.send(entry.path).await?;
             }

crates/project/src/project_tests.rs 🔗

@@ -10402,7 +10402,7 @@ async fn search(
 ) -> Result<HashMap<String, Vec<Range<usize>>>> {
     let search_rx = project.update(cx, |project, cx| project.search(query, cx));
     let mut results = HashMap::default();
-    while let Ok(search_result) = search_rx.recv().await {
+    while let Ok(search_result) = search_rx.rx.recv().await {
         match search_result {
             SearchResult::Buffer { buffer, ranges } => {
                 results.entry(buffer).or_insert(ranges);

crates/project/src/search.rs 🔗

@@ -326,39 +326,65 @@ impl SearchQuery {
         }
     }
 
-    pub(crate) fn detect(
+    pub(crate) async fn detect(
         &self,
         mut reader: BufReader<Box<dyn Read + Send + Sync>>,
     ) -> Result<bool> {
+        let query_str = self.as_str();
+        let needle_len = query_str.len();
+        if needle_len == 0 {
+            return Ok(false);
+        }
         if self.as_str().is_empty() {
             return Ok(false);
         }
 
+        let mut text = String::new();
+        let mut bytes_read = 0;
+        // Yield from this function every 128 bytes scanned.
+        const YIELD_THRESHOLD: usize = 128;
         match self {
             Self::Text { search, .. } => {
-                let mat = search.stream_find_iter(reader).next();
-                match mat {
-                    Some(Ok(_)) => Ok(true),
-                    Some(Err(err)) => Err(err.into()),
-                    None => Ok(false),
+                if query_str.contains('\n') {
+                    reader.read_to_string(&mut text)?;
+                    Ok(search.is_match(&text))
+                } else {
+                    // Yield from this function every 128 bytes scanned.
+                    const YIELD_THRESHOLD: usize = 128;
+                    while reader.read_line(&mut text)? > 0 {
+                        if search.is_match(&text) {
+                            return Ok(true);
+                        }
+                        bytes_read += text.len();
+                        if bytes_read >= YIELD_THRESHOLD {
+                            bytes_read = 0;
+                            smol::future::yield_now().await;
+                        }
+                        text.clear();
+                    }
+                    Ok(false)
                 }
             }
             Self::Regex {
                 regex, multiline, ..
             } => {
                 if *multiline {
-                    let mut text = String::new();
                     if let Err(err) = reader.read_to_string(&mut text) {
                         Err(err.into())
                     } else {
-                        Ok(regex.find(&text)?.is_some())
+                        Ok(regex.is_match(&text)?)
                     }
                 } else {
-                    for line in reader.lines() {
-                        let line = line?;
-                        if regex.find(&line)?.is_some() {
+                    while reader.read_line(&mut text)? > 0 {
+                        if regex.is_match(&text)? {
                             return Ok(true);
                         }
+                        bytes_read += text.len();
+                        if bytes_read >= YIELD_THRESHOLD {
+                            bytes_read = 0;
+                            smol::future::yield_now().await;
+                        }
+                        text.clear();
                     }
                     Ok(false)
                 }

crates/project/src/worktree_store.rs 🔗

@@ -1,26 +1,19 @@
 use std::{
-    io::{BufRead, BufReader},
     path::{Path, PathBuf},
-    pin::pin,
     sync::{Arc, atomic::AtomicUsize},
 };
 
 use anyhow::{Context as _, Result, anyhow, bail};
-use collections::{HashMap, HashSet};
+use collections::HashMap;
 use fs::{Fs, copy_recursive};
-use futures::{FutureExt, SinkExt, future::Shared};
+use futures::{FutureExt, future::Shared};
 use gpui::{
     App, AppContext as _, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity,
 };
-use postage::oneshot;
 use rpc::{
     AnyProtoClient, ErrorExt, TypedEnvelope,
     proto::{self, REMOTE_SERVER_PROJECT_ID},
 };
-use smol::{
-    channel::{Receiver, Sender},
-    stream::StreamExt,
-};
 use text::ReplicaId;
 use util::{
     ResultExt,
@@ -29,16 +22,10 @@ use util::{
 };
 use worktree::{
     CreatedEntry, Entry, ProjectEntryId, UpdatedEntriesSet, UpdatedGitRepositoriesSet, Worktree,
-    WorktreeId, WorktreeSettings,
+    WorktreeId,
 };
 
-use crate::{ProjectPath, search::SearchQuery};
-
-struct MatchingEntry {
-    worktree_root: Arc<Path>,
-    path: ProjectPath,
-    respond: oneshot::Sender<ProjectPath>,
-}
+use crate::ProjectPath;
 
 enum WorktreeStoreState {
     Local {
@@ -922,134 +909,6 @@ impl WorktreeStore {
         }
     }
 
-    /// search over all worktrees and return buffers that *might* match the search.
-    pub fn find_search_candidates(
-        &self,
-        query: SearchQuery,
-        limit: usize,
-        open_entries: HashSet<ProjectEntryId>,
-        fs: Arc<dyn Fs>,
-        cx: &Context<Self>,
-    ) -> Receiver<ProjectPath> {
-        let snapshots = self
-            .visible_worktrees(cx)
-            .filter_map(|tree| {
-                let tree = tree.read(cx);
-                Some((tree.snapshot(), tree.as_local()?.settings()))
-            })
-            .collect::<Vec<_>>();
-
-        let executor = cx.background_executor().clone();
-
-        // We want to return entries in the order they are in the worktrees, so we have one
-        // thread that iterates over the worktrees (and ignored directories) as necessary,
-        // and pushes a oneshot::Receiver to the output channel and a oneshot::Sender to the filter
-        // channel.
-        // We spawn a number of workers that take items from the filter channel and check the query
-        // against the version of the file on disk.
-        let (filter_tx, filter_rx) = smol::channel::bounded(64);
-        let (output_tx, output_rx) = smol::channel::bounded(64);
-        let (matching_paths_tx, matching_paths_rx) = smol::channel::unbounded();
-
-        let input = cx.background_spawn({
-            let fs = fs.clone();
-            let query = query.clone();
-            async move {
-                Self::find_candidate_paths(
-                    fs,
-                    snapshots,
-                    open_entries,
-                    query,
-                    filter_tx,
-                    output_tx,
-                )
-                .await
-                .log_err();
-            }
-        });
-        const MAX_CONCURRENT_FILE_SCANS: usize = 64;
-        let filters = cx.background_spawn(async move {
-            let fs = &fs;
-            let query = &query;
-            executor
-                .scoped(move |scope| {
-                    for _ in 0..MAX_CONCURRENT_FILE_SCANS {
-                        let filter_rx = filter_rx.clone();
-                        scope.spawn(async move {
-                            Self::filter_paths(fs, filter_rx, query)
-                                .await
-                                .log_with_level(log::Level::Debug);
-                        })
-                    }
-                })
-                .await;
-        });
-        cx.background_spawn(async move {
-            let mut matched = 0;
-            while let Ok(mut receiver) = output_rx.recv().await {
-                let Some(path) = receiver.next().await else {
-                    continue;
-                };
-                let Ok(_) = matching_paths_tx.send(path).await else {
-                    break;
-                };
-                matched += 1;
-                if matched == limit {
-                    break;
-                }
-            }
-            drop(input);
-            drop(filters);
-        })
-        .detach();
-        matching_paths_rx
-    }
-
-    async fn find_candidate_paths(
-        _: Arc<dyn Fs>,
-        _: Vec<(worktree::Snapshot, WorktreeSettings)>,
-        _: HashSet<ProjectEntryId>,
-        _: SearchQuery,
-        _: Sender<MatchingEntry>,
-        _: Sender<oneshot::Receiver<ProjectPath>>,
-    ) -> Result<()> {
-        Ok(())
-    }
-
-    async fn filter_paths(
-        fs: &Arc<dyn Fs>,
-        input: Receiver<MatchingEntry>,
-        query: &SearchQuery,
-    ) -> Result<()> {
-        let mut input = pin!(input);
-        while let Some(mut entry) = input.next().await {
-            let abs_path = entry.worktree_root.join(entry.path.path.as_std_path());
-            let Some(file) = fs.open_sync(&abs_path).await.log_err() else {
-                continue;
-            };
-
-            let mut file = BufReader::new(file);
-            let file_start = file.fill_buf()?;
-
-            if let Err(Some(starting_position)) =
-                std::str::from_utf8(file_start).map_err(|e| e.error_len())
-            {
-                // Before attempting to match the file content, throw away files that have invalid UTF-8 sequences early on;
-                // That way we can still match files in a streaming fashion without having look at "obviously binary" files.
-                log::debug!(
-                    "Invalid UTF-8 sequence in file {abs_path:?} at byte position {starting_position}"
-                );
-                continue;
-            }
-
-            if query.detect(file).unwrap_or(false) {
-                entry.respond.send(entry.path).await?
-            }
-        }
-
-        Ok(())
-    }
-
     pub async fn handle_create_project_entry(
         this: Entity<Self>,
         envelope: TypedEnvelope<proto::CreateProjectEntry>,

crates/project_benchmarks/src/main.rs 🔗

@@ -107,7 +107,7 @@ fn main() -> Result<(), anyhow::Error> {
                     .unwrap();
                 let mut matched_files = 0;
                 let mut matched_chunks = 0;
-                while let Ok(match_result) = matches.recv().await {
+                while let Ok(match_result) = matches.rx.recv().await {
                     if first_match.is_none() {
                         let time = timer.elapsed();
                         first_match = Some(time);

crates/remote_server/src/headless_project.rs 🔗

@@ -793,7 +793,7 @@ impl HeadlessProject {
 
         let buffer_store = this.read_with(&cx, |this, _| this.buffer_store.clone())?;
 
-        while let Ok(buffer) = results.recv().await {
+        while let Ok(buffer) = results.rx.recv().await {
             let buffer_id = buffer.read_with(&cx, |this, _| this.remote_id())?;
             response.buffer_ids.push(buffer_id.to_proto());
             buffer_store

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -211,7 +211,7 @@ async fn test_remote_project_search(cx: &mut TestAppContext, server_cx: &mut Tes
             )
         });
 
-        let first_response = receiver.recv().await.unwrap();
+        let first_response = receiver.rx.recv().await.unwrap();
         let SearchResult::Buffer { buffer, .. } = first_response else {
             panic!("incorrect result");
         };
@@ -222,7 +222,7 @@ async fn test_remote_project_search(cx: &mut TestAppContext, server_cx: &mut Tes
             )
         });
 
-        assert!(receiver.recv().await.is_err());
+        assert!(receiver.rx.recv().await.is_err());
         buffer
     }
 

crates/search/src/project_search.rs 🔗

@@ -25,7 +25,7 @@ use itertools::Itertools;
 use language::{Buffer, Language};
 use menu::Confirm;
 use project::{
-    Project, ProjectPath,
+    Project, ProjectPath, SearchResults,
     search::{SearchInputKind, SearchQuery},
     search_history::SearchHistoryCursor,
 };
@@ -326,7 +326,9 @@ impl ProjectSearch {
         self.active_query = Some(query);
         self.match_ranges.clear();
         self.pending_search = Some(cx.spawn(async move |project_search, cx| {
-            let mut matches = pin!(search.ready_chunks(1024));
+            let SearchResults { rx, _task_handle } = search;
+
+            let mut matches = pin!(rx.ready_chunks(1024));
             project_search
                 .update(cx, |project_search, cx| {
                     project_search.match_ranges.clear();