lib.rs

  1use crossbeam_channel as chan;
  2use std::{marker::PhantomData, mem::transmute, thread};
  3
  4#[derive(Clone)]
  5pub struct Pool {
  6    req_tx: chan::Sender<Request>,
  7    thread_count: usize,
  8}
  9
 10pub struct Scope<'a> {
 11    req_count: usize,
 12    req_tx: chan::Sender<Request>,
 13    resp_tx: chan::Sender<()>,
 14    resp_rx: chan::Receiver<()>,
 15    phantom: PhantomData<&'a ()>,
 16}
 17
 18struct Request {
 19    callback: Box<dyn FnOnce() + Send + 'static>,
 20    resp_tx: chan::Sender<()>,
 21}
 22
 23impl Pool {
 24    pub fn new(thread_count: usize, name: impl AsRef<str>) -> Self {
 25        let (req_tx, req_rx) = chan::unbounded();
 26        for i in 0..thread_count {
 27            thread::Builder::new()
 28                .name(format!("scoped_pool {} {}", name.as_ref(), i))
 29                .spawn({
 30                    let req_rx = req_rx.clone();
 31                    move || loop {
 32                        match req_rx.recv() {
 33                            Err(_) => break,
 34                            Ok(Request { callback, resp_tx }) => {
 35                                callback();
 36                                resp_tx.send(()).ok();
 37                            }
 38                        }
 39                    }
 40                })
 41                .expect("scoped_pool: failed to spawn thread");
 42        }
 43        Self {
 44            req_tx,
 45            thread_count,
 46        }
 47    }
 48
 49    pub fn thread_count(&self) -> usize {
 50        self.thread_count
 51    }
 52
 53    pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
 54    where
 55        F: FnOnce(&mut Scope<'scope>) -> R,
 56    {
 57        let (resp_tx, resp_rx) = chan::bounded(1);
 58        let mut scope = Scope {
 59            resp_tx,
 60            resp_rx,
 61            req_count: 0,
 62            phantom: PhantomData,
 63            req_tx: self.req_tx.clone(),
 64        };
 65        let result = scheduler(&mut scope);
 66        scope.wait();
 67        result
 68    }
 69}
 70
 71impl<'scope> Scope<'scope> {
 72    pub fn execute<F>(&mut self, callback: F)
 73    where
 74        F: FnOnce() + Send + 'scope,
 75    {
 76        // Transmute the callback's lifetime to be 'static. This is safe because in ::wait,
 77        // we block until all the callbacks have been called and dropped.
 78        let callback = unsafe {
 79            transmute::<Box<dyn FnOnce() + Send + 'scope>, Box<dyn FnOnce() + Send + 'static>>(
 80                Box::new(callback),
 81            )
 82        };
 83
 84        self.req_count += 1;
 85        self.req_tx
 86            .send(Request {
 87                callback,
 88                resp_tx: self.resp_tx.clone(),
 89            })
 90            .unwrap();
 91    }
 92
 93    fn wait(&self) {
 94        for _ in 0..self.req_count {
 95            self.resp_rx.recv().unwrap();
 96        }
 97    }
 98}
 99
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Sorted contents expected after three jobs each push `0..5`.
    const EXPECTED: [i32; 15] = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4];

    /// Submits three jobs that each push `0..5` into a shared vector, then checks that all
    /// fifteen values are present once `scoped` has returned.
    fn push_three_ranges(pool: &Pool) {
        let collected = Mutex::new(Vec::new());
        pool.scoped(|scope| {
            for _ in 0..3 {
                scope.execute(|| {
                    for value in 0..5 {
                        collected.lock().unwrap().push(value);
                    }
                });
            }
        });

        let mut collected = collected.into_inner().unwrap();
        collected.sort_unstable();
        assert_eq!(collected, EXPECTED)
    }

    #[test]
    fn test_execute() {
        let pool = Pool::new(3, "test");
        push_three_ranges(&pool);
    }

    #[test]
    fn test_clone_send_and_execute() {
        let pool = Pool::new(3, "test");

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let owned = pool.clone();
                thread::spawn(move || push_three_ranges(&owned))
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_share_and_execute() {
        let pool = Arc::new(Pool::new(3, "test"));

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let shared = Arc::clone(&pool);
                thread::spawn(move || push_three_ranges(&shared))
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }
    }
}