Use our own scoped_pool implementation

Created by Max Brunsfeld

Change summary

Cargo.lock             |  28 -----
Cargo.toml             |   2 
gpui/Cargo.toml        |   2 
gpui/src/app.rs        |   2 
scoped_pool/Cargo.toml |   8 +
scoped_pool/src/lib.rs | 188 ++++++++++++++++++++++++++++++++++++++++++++
zed/src/worktree.rs    |  12 --
7 files changed, 205 insertions(+), 37 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -448,12 +448,6 @@ dependencies = [
  "cfg-if 1.0.0",
 ]
 
-[[package]]
-name = "crossbeam"
-version = "0.2.12"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bd66663db5a988098a89599d4857919b3acf7f61402e61365acfd3919857b9be"
-
 [[package]]
 name = "crossbeam-channel"
 version = "0.4.4"
@@ -1062,7 +1056,7 @@ version = "0.4.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dd96ffd135b2fd7b973ac026d28085defbe8983df057ced3eb4f2130b0831312"
 dependencies = [
- "scopeguard 1.1.0",
+ "scopeguard",
 ]
 
 [[package]]
@@ -1714,13 +1708,9 @@ dependencies = [
 
 [[package]]
 name = "scoped-pool"
-version = "1.0.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "817a3a15e704545ce59ed2b5c60a5d32bda4d7869befb8b36667b658a6c00b43"
+version = "0.0.1"
 dependencies = [
- "crossbeam",
- "scopeguard 0.1.2",
- "variance",
+ "crossbeam-channel 0.5.0",
 ]
 
 [[package]]
@@ -1729,12 +1719,6 @@ version = "1.0.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2"
 
-[[package]]
-name = "scopeguard"
-version = "0.1.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "59a076157c1e2dc561d8de585151ee6965d910dd4dcb5dabb7ae3e83981a6c57"
-
 [[package]]
 name = "scopeguard"
 version = "1.1.0"
@@ -2138,12 +2122,6 @@ dependencies = [
  "xmlwriter",
 ]
 
-[[package]]
-name = "variance"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3abfc2be1fb59663871379ea884fd81de80c496f2274e021c01d6fe56cd77b05"
-
 [[package]]
 name = "vec-arena"
 version = "1.0.0"

Cargo.toml 🔗

@@ -1,5 +1,5 @@
 [workspace]
-members = ["zed", "gpui", "fsevent"]
+members = ["zed", "gpui", "fsevent", "scoped_pool"]
 
 [patch.crates-io]
 async-task = {git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e"}

gpui/Cargo.toml 🔗

@@ -18,7 +18,7 @@ postage = {version = "0.4.1", features = ["futures-traits"]}
 rand = "0.8.3"
 replace_with = "0.1.7"
 resvg = "0.14"
-scoped-pool = "1.0.0"
+scoped-pool = {path = "../scoped_pool"}
 seahash = "4.1"
 serde = {version = "1.0.125", features = ["derive"]}
 serde_json = "1.0.64"

gpui/src/app.rs 🔗

@@ -411,7 +411,7 @@ impl MutableAppContext {
                 windows: HashMap::new(),
                 ref_counts: Arc::new(Mutex::new(RefCounts::default())),
                 background: Arc::new(executor::Background::new()),
-                thread_pool: scoped_pool::Pool::new(num_cpus::get()),
+                thread_pool: scoped_pool::Pool::new(num_cpus::get(), "app"),
             },
             actions: HashMap::new(),
             global_actions: HashMap::new(),

scoped_pool/Cargo.toml 🔗

@@ -0,0 +1,8 @@
+[package]
+name = "scoped-pool"
+version = "0.0.1"
+license = "MIT"
+edition = "2018"
+
+[dependencies]
+crossbeam-channel = "0.5"

scoped_pool/src/lib.rs 🔗

@@ -0,0 +1,188 @@
+use crossbeam_channel as chan;
+use std::{marker::PhantomData, mem::transmute, thread};
+
+#[derive(Clone)]
+pub struct Pool {
+    req_tx: chan::Sender<Request>,
+    thread_count: usize,
+}
+
+pub struct Scope<'a> {
+    req_count: usize,
+    req_tx: chan::Sender<Request>,
+    resp_tx: chan::Sender<()>,
+    resp_rx: chan::Receiver<()>,
+    phantom: PhantomData<&'a ()>,
+}
+
+struct Request {
+    callback: Box<dyn FnOnce() + Send + 'static>,
+    resp_tx: chan::Sender<()>,
+}
+
+impl Pool {
+    pub fn new(thread_count: usize, name: &str) -> Self {
+        let (req_tx, req_rx) = chan::unbounded();
+        for i in 0..thread_count {
+            thread::Builder::new()
+                .name(format!("scoped_pool {} {}", name, i))
+                .spawn({
+                    let req_rx = req_rx.clone();
+                    move || loop {
+                        match req_rx.recv() {
+                            Err(_) => break,
+                            Ok(Request { callback, resp_tx }) => {
+                                callback();
+                                resp_tx.send(()).ok();
+                            }
+                        }
+                    }
+                })
+                .expect("scoped_pool: failed to spawn thread");
+        }
+        Self {
+            req_tx,
+            thread_count,
+        }
+    }
+
+    pub fn thread_count(&self) -> usize {
+        self.thread_count
+    }
+
+    pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
+    where
+        F: FnOnce(&mut Scope<'scope>) -> R,
+    {
+        let (resp_tx, resp_rx) = chan::bounded(1);
+        let mut scope = Scope {
+            resp_tx,
+            resp_rx,
+            req_count: 0,
+            phantom: PhantomData,
+            req_tx: self.req_tx.clone(),
+        };
+        let result = scheduler(&mut scope);
+        scope.wait();
+        result
+    }
+}
+
+impl<'scope> Scope<'scope> {
+    pub fn execute<F>(&mut self, callback: F)
+    where
+        F: FnOnce() + Send + 'scope,
+    {
+        // Transmute the callback's lifetime to 'static. This is sound because
+        // `Pool::scoped` calls `Scope::wait` before returning, blocking until every
+        // queued callback has been called and dropped, so no callback can outlive
+        // the `'scope` data it borrows. NOTE(review): this guarantee depends on
+        // `wait` actually being reached — if the scheduler closure panics before
+        // `wait` runs, queued callbacks could outlive `'scope`; consider waiting
+        // in `Drop` as well.
+        let callback = unsafe {
+            transmute::<Box<dyn FnOnce() + Send + 'scope>, Box<dyn FnOnce() + Send + 'static>>(
+                Box::new(callback),
+            )
+        };
+
+        self.req_count += 1;
+        self.req_tx
+            .send(Request {
+                callback,
+                resp_tx: self.resp_tx.clone(),
+            })
+            .unwrap();
+    }
+
+    fn wait(&self) {
+        for _ in 0..self.req_count {
+            self.resp_rx.recv().unwrap();
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::{Arc, Mutex};
+
+    #[test]
+    fn test_execute() {
+        let pool = Pool::new(3, "test");
+
+        {
+            let vec = Mutex::new(Vec::new());
+            pool.scoped(|scope| {
+                for _ in 0..3 {
+                    scope.execute(|| {
+                        for i in 0..5 {
+                            vec.lock().unwrap().push(i);
+                        }
+                    });
+                }
+            });
+
+            let mut vec = vec.into_inner().unwrap();
+            vec.sort_unstable();
+            assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
+        }
+    }
+
+    #[test]
+    fn test_clone_send_and_execute() {
+        let pool = Pool::new(3, "test");
+
+        let mut threads = Vec::new();
+        for _ in 0..3 {
+            threads.push(thread::spawn({
+                let pool = pool.clone();
+                move || {
+                    let vec = Mutex::new(Vec::new());
+                    pool.scoped(|scope| {
+                        for _ in 0..3 {
+                            scope.execute(|| {
+                                for i in 0..5 {
+                                    vec.lock().unwrap().push(i);
+                                }
+                            });
+                        }
+                    });
+                    let mut vec = vec.into_inner().unwrap();
+                    vec.sort_unstable();
+                    assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
+                }
+            }));
+        }
+
+        for thread in threads {
+            thread.join().unwrap();
+        }
+    }
+
+    #[test]
+    fn test_share_and_execute() {
+        let pool = Arc::new(Pool::new(3, "test"));
+
+        let mut threads = Vec::new();
+        for _ in 0..3 {
+            threads.push(thread::spawn({
+                let pool = pool.clone();
+                move || {
+                    let vec = Mutex::new(Vec::new());
+                    pool.scoped(|scope| {
+                        for _ in 0..3 {
+                            scope.execute(|| {
+                                for i in 0..5 {
+                                    vec.lock().unwrap().push(i);
+                                }
+                            });
+                        }
+                    });
+                    let mut vec = vec.into_inner().unwrap();
+                    vec.sort_unstable();
+                    assert_eq!(vec, [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
+                }
+            }));
+        }
+
+        for thread in threads {
+            thread.join().unwrap();
+        }
+    }
+}

zed/src/worktree.rs 🔗

@@ -500,7 +500,7 @@ impl BackgroundScanner {
         Self {
             snapshot: Mutex::new(snapshot),
             notify,
-            thread_pool: scoped_pool::Pool::new(16),
+            thread_pool: scoped_pool::Pool::new(16, "background-scanner"),
         }
     }
 
@@ -592,7 +592,7 @@ impl BackgroundScanner {
             drop(tx);
 
             let mut results = Vec::new();
-            results.resize_with(self.thread_pool.workers(), || Ok(()));
+            results.resize_with(self.thread_pool.thread_count(), || Ok(()));
             self.thread_pool.scoped(|pool| {
                 for result in &mut results {
                     pool.execute(|| {
@@ -762,7 +762,7 @@ impl BackgroundScanner {
         // Scan any directories that were created as part of this event batch.
         drop(scan_queue_tx);
         self.thread_pool.scoped(|pool| {
-            for _ in 0..self.thread_pool.workers() {
+            for _ in 0..self.thread_pool.thread_count() {
                 pool.execute(|| {
                     while let Ok(job) = scan_queue_rx.recv() {
                         if let Err(err) = job.and_then(|job| self.scan_dir(job)) {
@@ -844,12 +844,6 @@ impl BackgroundScanner {
     }
 }
 
-impl Drop for BackgroundScanner {
-    fn drop(&mut self) {
-        self.thread_pool.shutdown();
-    }
-}
-
 struct ScanJob {
     inode: u64,
     path: Arc<Path>,