1use crossbeam_channel as chan;
2use std::{marker::PhantomData, mem::transmute, thread};
3
4#[derive(Clone)]
5pub struct Pool {
6 req_tx: chan::Sender<Request>,
7 thread_count: usize,
8}
9
10pub struct Scope<'a> {
11 req_count: usize,
12 req_tx: chan::Sender<Request>,
13 resp_tx: chan::Sender<()>,
14 resp_rx: chan::Receiver<()>,
15 phantom: PhantomData<&'a ()>,
16}
17
18struct Request {
19 callback: Box<dyn FnOnce() + Send + 'static>,
20 resp_tx: chan::Sender<()>,
21}
22
23impl Pool {
24 pub fn new(thread_count: usize, name: impl AsRef<str>) -> Self {
25 let (req_tx, req_rx) = chan::unbounded();
26 for i in 0..thread_count {
27 thread::Builder::new()
28 .name(format!("scoped_pool {} {}", name.as_ref(), i))
29 .spawn({
30 let req_rx = req_rx.clone();
31 move || loop {
32 match req_rx.recv() {
33 Err(_) => break,
34 Ok(Request { callback, resp_tx }) => {
35 callback();
36 resp_tx.send(()).ok();
37 }
38 }
39 }
40 })
41 .expect("scoped_pool: failed to spawn thread");
42 }
43 Self {
44 req_tx,
45 thread_count,
46 }
47 }
48
49 pub fn thread_count(&self) -> usize {
50 self.thread_count
51 }
52
53 pub fn scoped<'scope, F, R>(&self, scheduler: F) -> R
54 where
55 F: FnOnce(&mut Scope<'scope>) -> R,
56 {
57 let (resp_tx, resp_rx) = chan::bounded(1);
58 let mut scope = Scope {
59 resp_tx,
60 resp_rx,
61 req_count: 0,
62 phantom: PhantomData,
63 req_tx: self.req_tx.clone(),
64 };
65 let result = scheduler(&mut scope);
66 scope.wait();
67 result
68 }
69}
70
71impl<'scope> Scope<'scope> {
72 pub fn execute<F>(&mut self, callback: F)
73 where
74 F: FnOnce() + Send + 'scope,
75 {
76 // Transmute the callback's lifetime to be 'static. This is safe because in ::wait,
77 // we block until all the callbacks have been called and dropped.
78 let callback = unsafe {
79 transmute::<Box<dyn FnOnce() + Send + 'scope>, Box<dyn FnOnce() + Send + 'static>>(
80 Box::new(callback),
81 )
82 };
83
84 self.req_count += 1;
85 self.req_tx
86 .send(Request {
87 callback,
88 resp_tx: self.resp_tx.clone(),
89 })
90 .unwrap();
91 }
92
93 fn wait(&self) {
94 for _ in 0..self.req_count {
95 self.resp_rx.recv().unwrap();
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    };

    /// Sorted contents expected after three jobs each push `0..5`.
    const EXPECTED: [i32; 15] = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4];

    /// Submits three jobs that each push `0..5` into a vector living on this stack frame,
    /// then checks that all fifteen values arrived by the time `scoped` returned.
    fn run_three_jobs(pool: &Pool) {
        let vec = Mutex::new(Vec::new());
        pool.scoped(|scope| {
            for _ in 0..3 {
                scope.execute(|| {
                    for i in 0..5 {
                        vec.lock().unwrap().push(i);
                    }
                });
            }
        });
        let mut vec = vec.into_inner().unwrap();
        vec.sort_unstable();
        assert_eq!(vec, EXPECTED);
    }

    #[test]
    fn test_execute() {
        let pool = Pool::new(3, "test");
        run_three_jobs(&pool);
    }

    #[test]
    fn test_clone_send_and_execute() {
        let pool = Pool::new(3, "test");
        let threads: Vec<_> = (0..3)
            .map(|_| {
                let pool = pool.clone();
                thread::spawn(move || run_three_jobs(&pool))
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }
    }

    #[test]
    fn test_share_and_execute() {
        let pool = Arc::new(Pool::new(3, "test"));
        let threads: Vec<_> = (0..3)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || run_three_jobs(&pool))
            })
            .collect();
        for thread in threads {
            thread.join().unwrap();
        }
    }

    #[test]
    fn test_scoped_returns_scheduler_result() {
        let pool = Pool::new(2, "test");
        assert_eq!(pool.thread_count(), 2);
        assert_eq!(pool.scoped(|_| 42), 42);
    }

    #[test]
    fn test_jobs_borrow_disjoint_slots() {
        let pool = Pool::new(2, "test");
        let slots: Vec<AtomicUsize> = (0..4).map(|_| AtomicUsize::new(0)).collect();
        pool.scoped(|scope| {
            for (i, slot) in slots.iter().enumerate() {
                scope.execute(move || slot.store(i * 10, Ordering::SeqCst));
            }
        });
        let values: Vec<usize> = slots.iter().map(|s| s.load(Ordering::SeqCst)).collect();
        assert_eq!(values, [0, 10, 20, 30]);
    }
}