Use an `executor::Background` in `AppContext::thread_pool`

Antonio Scandurra created

Change summary

Cargo.lock                |   8 -
Cargo.toml                |  10 -
gpui/Cargo.toml           |   1 
gpui/src/app.rs           |  13 +
gpui/src/lib.rs           |   1 
gpui_macros/src/lib.rs    |   2 
scoped_pool/Cargo.toml    |   8 -
scoped_pool/src/lib.rs    | 188 -----------------------------------------
zed/src/file_finder.rs    |  14 +-
zed/src/worktree.rs       |  58 ++++++------
zed/src/worktree/fuzzy.rs |  16 +-
11 files changed, 54 insertions(+), 265 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2173,7 +2173,6 @@ dependencies = [
  "rand 0.8.3",
  "replace_with",
  "resvg",
- "scoped-pool",
  "seahash",
  "serde 1.0.125",
  "serde_json 1.0.64",
@@ -4164,13 +4163,6 @@ dependencies = [
  "winapi 0.3.9",
 ]
 
-[[package]]
-name = "scoped-pool"
-version = "0.0.1"
-dependencies = [
- "crossbeam-channel",
-]
-
 [[package]]
 name = "scoped-tls"
 version = "1.0.0"

Cargo.toml 🔗

@@ -1,13 +1,5 @@
 [workspace]
-members = [
-    "fsevent",
-    "gpui",
-    "gpui_macros",
-    "scoped_pool",
-    "server",
-    "zed",
-    "zed-rpc"
-]
+members = ["fsevent", "gpui", "gpui_macros", "server", "zed", "zed-rpc"]
 
 [patch.crates-io]
 async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" }

gpui/Cargo.toml 🔗

@@ -20,7 +20,6 @@ postage = { version = "0.4.1", features = ["futures-traits"] }
 rand = "0.8.3"
 replace_with = "0.1.7"
 resvg = "0.14"
-scoped-pool = { path = "../scoped_pool" }
 seahash = "4.1"
 serde = { version = "1.0.125", features = ["derive"] }
 serde_json = "1.0.64"

gpui/src/app.rs 🔗

@@ -123,6 +123,7 @@ impl App {
         let cx = Rc::new(RefCell::new(MutableAppContext::new(
             foreground,
             Arc::new(executor::Background::new()),
+            Arc::new(executor::Background::new()),
             Arc::new(platform),
             Rc::new(foreground_platform),
             (),
@@ -139,6 +140,7 @@ impl App {
         let app = Self(Rc::new(RefCell::new(MutableAppContext::new(
             foreground,
             Arc::new(executor::Background::new()),
+            Arc::new(executor::Background::new()),
             platform.clone(),
             foreground_platform.clone(),
             asset_source,
@@ -245,6 +247,7 @@ impl TestAppContext {
     pub fn new(
         foreground: Rc<executor::Foreground>,
         background: Arc<executor::Background>,
+        thread_pool: Arc<executor::Background>,
         first_entity_id: usize,
     ) -> Self {
         let platform = Arc::new(platform::test::platform());
@@ -252,6 +255,7 @@ impl TestAppContext {
         let mut cx = MutableAppContext::new(
             foreground.clone(),
             background,
+            thread_pool,
             platform,
             foreground_platform.clone(),
             (),
@@ -590,6 +594,7 @@ impl MutableAppContext {
     fn new(
         foreground: Rc<executor::Foreground>,
         background: Arc<executor::Background>,
+        thread_pool: Arc<executor::Background>,
         platform: Arc<dyn platform::Platform>,
         foreground_platform: Rc<dyn platform::ForegroundPlatform>,
         asset_source: impl AssetSource,
@@ -607,7 +612,7 @@ impl MutableAppContext {
                 values: Default::default(),
                 ref_counts: Arc::new(Mutex::new(RefCounts::default())),
                 background,
-                thread_pool: scoped_pool::Pool::new(num_cpus::get(), "app"),
+                thread_pool,
                 font_cache: Arc::new(FontCache::new(fonts)),
             },
             actions: HashMap::new(),
@@ -1485,7 +1490,7 @@ pub struct AppContext {
     values: RwLock<HashMap<(TypeId, usize), Box<dyn Any>>>,
     background: Arc<executor::Background>,
     ref_counts: Arc<Mutex<RefCounts>>,
-    thread_pool: scoped_pool::Pool,
+    thread_pool: Arc<executor::Background>,
     font_cache: Arc<FontCache>,
 }
 
@@ -1530,7 +1535,7 @@ impl AppContext {
         &self.font_cache
     }
 
-    pub fn thread_pool(&self) -> &scoped_pool::Pool {
+    pub fn thread_pool(&self) -> &Arc<executor::Background> {
         &self.thread_pool
     }
 
@@ -1716,7 +1721,7 @@ impl<'a, T: Entity> ModelContext<'a, T> {
         &self.app.cx.background
     }
 
-    pub fn thread_pool(&self) -> &scoped_pool::Pool {
+    pub fn thread_pool(&self) -> &Arc<executor::Background> {
         &self.app.cx.thread_pool
     }
 

gpui/src/lib.rs 🔗

@@ -30,4 +30,3 @@ pub use presenter::{
     AfterLayoutContext, Axis, DebugContext, EventContext, LayoutContext, PaintContext,
     SizeConstraint, Vector2FExt,
 };
-pub use scoped_pool;

gpui_macros/src/lib.rs 🔗

@@ -60,7 +60,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
     let inner_fn_args = (0..inner_fn.sig.inputs.len())
         .map(|i| {
             let first_entity_id = i * 100_000;
-            quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
+            quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), background.clone(), #first_entity_id),)
         })
         .collect::<proc_macro2::TokenStream>();
 

scoped_pool/Cargo.toml 🔗

@@ -1,8 +0,0 @@
-[package]
-name = "scoped-pool"
-version = "0.0.1"
-license = "MIT"
-edition = "2018"
-
-[dependencies]
-crossbeam-channel = "0.5"

scoped_pool/src/lib.rs 🔗

@@ -1,188 +0,0 @@
-use crossbeam_channel as chan;
-use std::{marker::PhantomData, mem::transmute, thread};
-
-#[derive(Clone)]
-pub struct Pool {
-    req_tx: chan::Sender<Request>,
-    thread_count: usize,
-}
-
-pub struct Scope<'a> {
-    req_count: usize,
-    req_tx: chan::Sender<Request>,
-    resp_tx: chan::Sender<()>,
-    resp_rx: chan::Receiver<()>,
-    phantom: PhantomData<&'a ()>,
-}
-
-struct Request {
-    callback: Box<dyn FnOnce() + Send + 'static>,
-    resp_tx: chan::Sender<()>,
-}
-
-impl Pool {
-    pub fn new(thread_count: usize, name: impl AsRef<str>) -> Self {
-        let (req_tx, req_rx) = chan::unbounded();
-        for i in 0..thread_count {
-            thread::Builder::new()
-                .name(format!("scoped_pool {} {}", name.as_ref(), i))
-                .spawn({
-                    let req_rx = req_rx.clone();
-                    move || loop {
-                        match req_rx.recv() {
-                            Err(_) => break,
-                            Ok(Request { callback, resp_tx }) => {
-                                callback();
-                                resp_tx.send(()).ok();
-                            }
-                        }
-                    }
-                })
-                .expect("scoped_pool: failed to spawn thread");
-        }
-        Self {
-            req_tx,
-            thread_count,
-        }
-    }
-
-    pub fn thread_count(&self) -> usize {
-        self.thread_count
-    }
-
-    pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
-    where
-        F: FnOnce(&mut Scope<'scope>) -> R,
-    {
-        let (resp_tx, resp_rx) = chan::bounded(1);
-        let mut scope = Scope {
-            resp_tx,
-            resp_rx,
-            req_count: 0,
-            phantom: PhantomData,
-            req_tx: self.req_tx.clone(),
-        };
-        let result = scheduler(&mut scope);
-        scope.wait();
-        result
-    }
-}
-
-impl<'scope> Scope<'scope> {
-    pub fn execute<F>(&mut self, callback: F)
-    where
-        F: FnOnce() + Send + 'scope,
-    {
-        // Transmute the callback's lifetime to be 'static. This is safe because in ::wait,
-        // we block until all the callbacks have been called and dropped.
-        let callback = unsafe {
-            transmute::<Box<dyn FnOnce() + Send + 'scope>, Box<dyn FnOnce() + Send + 'static>>(
-                Box::new(callback),
-            )
-        };
-
-        self.req_count += 1;
-        self.req_tx
-            .send(Request {
-                callback,
-                resp_tx: self.resp_tx.clone(),
-            })
-            .unwrap();
-    }
-
-    fn wait(&self) {
-        for _ in 0..self.req_count {
-            self.resp_rx.recv().unwrap();
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use std::sync::{Arc, Mutex};
-
-    #[test]
-    fn test_execute() {
-        let pool = Pool::new(3, "test");
-
-        {
-            let vec = Mutex::new(Vec::new());
-            pool.scoped(|scope| {
-                for _ in 0..3 {
-                    scope.execute(|| {
-                        for i in 0..5 {
-                            vec.lock().unwrap().push(i);
-                        }
-                    });
-                }
-            });
-
-            let mut vec = vec.into_inner().unwrap();
-            vec.sort_unstable();
-            assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
-        }
-    }
-
-    #[test]
-    fn test_clone_send_and_execute() {
-        let pool = Pool::new(3, "test");
-
-        let mut threads = Vec::new();
-        for _ in 0..3 {
-            threads.push(thread::spawn({
-                let pool = pool.clone();
-                move || {
-                    let vec = Mutex::new(Vec::new());
-                    pool.scoped(|scope| {
-                        for _ in 0..3 {
-                            scope.execute(|| {
-                                for i in 0..5 {
-                                    vec.lock().unwrap().push(i);
-                                }
-                            });
-                        }
-                    });
-                    let mut vec = vec.into_inner().unwrap();
-                    vec.sort_unstable();
-                    assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
-                }
-            }));
-        }
-
-        for thread in threads {
-            thread.join().unwrap();
-        }
-    }
-
-    #[test]
-    fn test_share_and_execute() {
-        let pool = Arc::new(Pool::new(3, "test"));
-
-        let mut threads = Vec::new();
-        for _ in 0..3 {
-            threads.push(thread::spawn({
-                let pool = pool.clone();
-                move || {
-                    let vec = Mutex::new(Vec::new());
-                    pool.scoped(|scope| {
-                        for _ in 0..3 {
-                            scope.execute(|| {
-                                for i in 0..5 {
-                                    vec.lock().unwrap().push(i);
-                                }
-                            });
-                        }
-                    });
-                    let mut vec = vec.into_inner().unwrap();
-                    vec.sort_unstable();
-                    assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
-                }
-            }));
-        }
-
-        for thread in threads {
-            thread.join().unwrap();
-        }
-    }
-}

zed/src/file_finder.rs 🔗

@@ -403,7 +403,7 @@ impl FileFinder {
         self.cancel_flag.store(true, atomic::Ordering::Relaxed);
         self.cancel_flag = Arc::new(AtomicBool::new(false));
         let cancel_flag = self.cancel_flag.clone();
-        let background_task = cx.background_executor().spawn(async move {
+        Some(cx.spawn(|this, mut cx| async move {
             let include_root_name = snapshots.len() > 1;
             let matches = match_paths(
                 snapshots.iter(),
@@ -414,14 +414,12 @@ impl FileFinder {
                 100,
                 cancel_flag.clone(),
                 pool,
-            );
+            )
+            .await;
             let did_cancel = cancel_flag.load(atomic::Ordering::Relaxed);
-            (search_id, did_cancel, query, matches)
-        });
-
-        Some(cx.spawn(|this, mut cx| async move {
-            let matches = background_task.await;
-            this.update(&mut cx, |this, cx| this.update_matches(matches, cx));
+            this.update(&mut cx, |this, cx| {
+                this.update_matches((search_id, did_cancel, query, matches), cx)
+            });
         }))
     }
 

zed/src/worktree.rs 🔗

@@ -550,15 +550,11 @@ impl Worktree {
         let (mut tree, scan_states_tx) = LocalWorktree::new(path, languages, fs.clone(), cx);
         let abs_path = tree.snapshot.abs_path.clone();
         let background_snapshot = tree.background_snapshot.clone();
-        let background = if fs.is_fake() {
-            cx.background().clone()
-        } else {
-            Arc::new(executor::Background::new())
-        };
+        let thread_pool = cx.thread_pool().clone();
         tree._background_scanner_task = Some(cx.background().spawn(async move {
             let events = fs.watch(&abs_path, Duration::from_millis(100)).await;
             let scanner =
-                BackgroundScanner::new(background_snapshot, scan_states_tx, fs, background);
+                BackgroundScanner::new(background_snapshot, scan_states_tx, fs, thread_pool);
             scanner.run(events).await;
         }));
         Worktree::Local(tree)
@@ -3017,36 +3013,40 @@ mod tests {
 
         cx.read(|cx| tree.read(cx).as_local().unwrap().scan_complete())
             .await;
-        cx.read(|cx| {
+        let snapshot = cx.read(|cx| {
             let tree = tree.read(cx);
             assert_eq!(tree.file_count(), 5);
-
             assert_eq!(
                 tree.inode_for_path("fennel/grape"),
                 tree.inode_for_path("finnochio/grape")
             );
 
-            let results = match_paths(
-                Some(tree.snapshot()).iter(),
-                "bna",
-                false,
-                false,
-                false,
-                10,
-                Default::default(),
-                cx.thread_pool().clone(),
-            )
-            .into_iter()
-            .map(|result| result.path)
-            .collect::<Vec<Arc<Path>>>();
-            assert_eq!(
-                results,
-                vec![
-                    PathBuf::from("banana/carrot/date").into(),
-                    PathBuf::from("banana/carrot/endive").into(),
-                ]
-            );
-        })
+            tree.snapshot()
+        });
+        let results = cx
+            .read(|cx| {
+                match_paths(
+                    Some(&snapshot).into_iter(),
+                    "bna",
+                    false,
+                    false,
+                    false,
+                    10,
+                    Default::default(),
+                    cx.thread_pool().clone(),
+                )
+            })
+            .await;
+        assert_eq!(
+            results
+                .into_iter()
+                .map(|result| result.path)
+                .collect::<Vec<Arc<Path>>>(),
+            vec![
+                PathBuf::from("banana/carrot/date").into(),
+                PathBuf::from("banana/carrot/endive").into(),
+            ]
+        );
     }
 
     #[gpui::test]

zed/src/worktree/fuzzy.rs 🔗

@@ -1,6 +1,6 @@
 use super::{char_bag::CharBag, EntryKind, Snapshot};
 use crate::util;
-use gpui::scoped_pool;
+use gpui::executor;
 use std::{
     cmp::{max, min, Ordering},
     path::Path,
@@ -51,7 +51,7 @@ impl Ord for PathMatch {
     }
 }
 
-pub fn match_paths<'a, T>(
+pub async fn match_paths<'a, T>(
     snapshots: T,
     query: &str,
     include_root_name: bool,
@@ -59,7 +59,7 @@ pub fn match_paths<'a, T>(
     smart_case: bool,
     max_results: usize,
     cancel_flag: Arc<AtomicBool>,
-    pool: scoped_pool::Pool,
+    pool: Arc<executor::Background>,
 ) -> Vec<PathMatch>
 where
     T: Clone + Send + Iterator<Item = &'a Snapshot> + 'a,
@@ -71,15 +71,14 @@ where
     let query = &query;
     let query_chars = CharBag::from(&lowercase_query[..]);
 
-    let cpus = num_cpus::get();
     let path_count: usize = if include_ignored {
         snapshots.clone().map(Snapshot::file_count).sum()
     } else {
         snapshots.clone().map(Snapshot::visible_file_count).sum()
     };
 
-    let segment_size = (path_count + cpus - 1) / cpus;
-    let mut segment_results = (0..cpus)
+    let segment_size = (path_count + pool.threads() - 1) / pool.threads();
+    let mut segment_results = (0..pool.threads())
         .map(|_| Vec::with_capacity(max_results))
         .collect::<Vec<_>>();
 
@@ -87,7 +86,7 @@ where
         for (segment_idx, results) in segment_results.iter_mut().enumerate() {
             let snapshots = snapshots.clone();
             let cancel_flag = &cancel_flag;
-            scope.execute(move || {
+            scope.spawn(async move {
                 let segment_start = segment_idx * segment_size;
                 let segment_end = segment_start + segment_size;
 
@@ -152,7 +151,8 @@ where
                 }
             })
         }
-    });
+    })
+    .await;
 
     let mut results = Vec::new();
     for segment_result in segment_results {