1use std::{
2 fmt,
3 iter::FusedIterator,
4 sync::{Arc, atomic::AtomicUsize},
5};
6
7use rand::{Rng, SeedableRng, rngs::SmallRng};
8
9use crate::Priority;
10
/// Pending items, bucketed into one `Vec` per priority level.
struct PriorityQueues<T> {
    high_priority: Vec<T>,
    medium_priority: Vec<T>,
    low_priority: Vec<T>,
}

impl<T> PriorityQueues<T> {
    /// Returns `true` when none of the three priority buckets holds an item.
    fn is_empty(&self) -> bool {
        let buckets = [&self.high_priority, &self.medium_priority, &self.low_priority];
        buckets.iter().all(|bucket| bucket.is_empty())
    }
}
24
25struct PriorityQueueState<T> {
26 queues: parking_lot::Mutex<PriorityQueues<T>>,
27 condvar: parking_lot::Condvar,
28 receiver_count: AtomicUsize,
29 sender_count: AtomicUsize,
30}
31
32impl<T> PriorityQueueState<T> {
33 fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
34 if self
35 .receiver_count
36 .load(std::sync::atomic::Ordering::Relaxed)
37 == 0
38 {
39 return Err(SendError(item));
40 }
41
42 let mut queues = self.queues.lock();
43 match priority {
44 Priority::Realtime(_) => unreachable!(),
45 Priority::High => queues.high_priority.push(item),
46 Priority::Medium => queues.medium_priority.push(item),
47 Priority::Low => queues.low_priority.push(item),
48 };
49 self.condvar.notify_one();
50 Ok(())
51 }
52
53 fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
54 let mut queues = self.queues.lock();
55
56 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
57 if queues.is_empty() && sender_count == 0 {
58 return Err(crate::queue::RecvError);
59 }
60
61 while queues.is_empty() {
62 self.condvar.wait(&mut queues);
63 }
64
65 Ok(queues)
66 }
67
68 fn try_recv<'a>(
69 &'a self,
70 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
71 let mut queues = self.queues.lock();
72
73 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
74 if queues.is_empty() && sender_count == 0 {
75 return Err(crate::queue::RecvError);
76 }
77
78 if queues.is_empty() {
79 Ok(None)
80 } else {
81 Ok(Some(queues))
82 }
83 }
84}
85
86pub(crate) struct PriorityQueueSender<T> {
87 state: Arc<PriorityQueueState<T>>,
88}
89
90impl<T> PriorityQueueSender<T> {
91 fn new(state: Arc<PriorityQueueState<T>>) -> Self {
92 Self { state }
93 }
94
95 pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
96 self.state.send(priority, item)?;
97 Ok(())
98 }
99}
100
101impl<T> Drop for PriorityQueueSender<T> {
102 fn drop(&mut self) {
103 self.state
104 .sender_count
105 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
106 }
107}
108
109pub(crate) struct PriorityQueueReceiver<T> {
110 state: Arc<PriorityQueueState<T>>,
111 rand: SmallRng,
112 disconnected: bool,
113}
114
115impl<T> Clone for PriorityQueueReceiver<T> {
116 fn clone(&self) -> Self {
117 self.state
118 .receiver_count
119 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
120 Self {
121 state: Arc::clone(&self.state),
122 rand: SmallRng::seed_from_u64(0),
123 disconnected: self.disconnected,
124 }
125 }
126}
127
/// Error returned by `send` when no receiver is left; it carries the rejected item.
pub(crate) struct SendError<T>(T);

impl<T: fmt::Debug> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let SendError(item) = self;
        let mut builder = f.debug_tuple("SendError");
        builder.field(item);
        builder.finish()
    }
}

/// Error returned on the receiving side once the queue is empty and the sender was dropped.
#[derive(Debug)]
pub(crate) struct RecvError;
138
139#[allow(dead_code)]
140impl<T> PriorityQueueReceiver<T> {
141 pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
142 let state = PriorityQueueState {
143 queues: parking_lot::Mutex::new(PriorityQueues {
144 high_priority: Vec::new(),
145 medium_priority: Vec::new(),
146 low_priority: Vec::new(),
147 }),
148 condvar: parking_lot::Condvar::new(),
149 receiver_count: AtomicUsize::new(1),
150 sender_count: AtomicUsize::new(1),
151 };
152 let state = Arc::new(state);
153
154 let sender = PriorityQueueSender::new(Arc::clone(&state));
155
156 let receiver = PriorityQueueReceiver {
157 state,
158 rand: SmallRng::seed_from_u64(0),
159 disconnected: false,
160 };
161
162 (sender, receiver)
163 }
164
165 /// Tries to pop one element from the priority queue without blocking.
166 ///
167 /// This will early return if there are no elements in the queue.
168 ///
169 /// This method is best suited if you only intend to pop one element, for better performance
170 /// on large queues see [`Self::try_iter`]
171 ///
172 /// # Errors
173 ///
174 /// If the sender was dropped
175 pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
176 self.pop_inner(false)
177 }
178
179 /// Pops an element from the priority queue blocking if necessary.
180 ///
181 /// This method is best suited if you only intend to pop one element, for better performance
182 /// on large queues see [`Self::iter``]
183 ///
184 /// # Errors
185 ///
186 /// If the sender was dropped
187 pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
188 self.pop_inner(true).map(|e| e.unwrap())
189 }
190
191 /// Returns an iterator over the elements of the queue
192 /// this iterator will end when all elements have been consumed and will not wait for new ones.
193 pub(crate) fn try_iter(self) -> TryIter<T> {
194 TryIter {
195 receiver: self,
196 ended: false,
197 }
198 }
199
200 /// Returns an iterator over the elements of the queue
201 /// this iterator will wait for new elements if the queue is empty.
202 pub(crate) fn iter(self) -> Iter<T> {
203 Iter(self)
204 }
205
206 #[inline(always)]
207 // algorithm is the loaded die from biased coin from
208 // https://www.keithschwarz.com/darts-dice-coins/
209 fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
210 use Priority as P;
211
212 let mut queues = if !block {
213 let Some(queues) = self.state.try_recv()? else {
214 return Ok(None);
215 };
216 queues
217 } else {
218 self.state.recv()?
219 };
220
221 let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
222 let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
223 let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
224 let mut mass = high + medium + low; //%
225
226 if !queues.high_priority.is_empty() {
227 let flip = self.rand.random_ratio(P::High.probability(), mass);
228 if flip {
229 return Ok(queues.high_priority.pop());
230 }
231 mass -= P::High.probability();
232 }
233
234 if !queues.medium_priority.is_empty() {
235 let flip = self.rand.random_ratio(P::Medium.probability(), mass);
236 if flip {
237 return Ok(queues.medium_priority.pop());
238 }
239 mass -= P::Medium.probability();
240 }
241
242 if !queues.low_priority.is_empty() {
243 let flip = self.rand.random_ratio(P::Low.probability(), mass);
244 if flip {
245 return Ok(queues.low_priority.pop());
246 }
247 }
248
249 Ok(None)
250 }
251}
252
253impl<T> Drop for PriorityQueueReceiver<T> {
254 fn drop(&mut self) {
255 self.state
256 .receiver_count
257 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
258 }
259}
260
261/// If None is returned the sender disconnected
262pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
263impl<T> Iterator for Iter<T> {
264 type Item = T;
265
266 fn next(&mut self) -> Option<Self::Item> {
267 self.0.pop().ok()
268 }
269}
270impl<T> FusedIterator for Iter<T> {}
271
272/// If None is returned there are no more elements in the queue
273pub(crate) struct TryIter<T> {
274 receiver: PriorityQueueReceiver<T>,
275 ended: bool,
276}
277impl<T> Iterator for TryIter<T> {
278 type Item = Result<T, RecvError>;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 if self.ended {
282 return None;
283 }
284
285 let res = self.receiver.try_pop();
286 self.ended = res.is_err();
287
288 res.transpose()
289 }
290}
291impl<T> FusedIterator for TryIter<T> {}
292
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every queued item, across all priority levels, is yielded before the blocking iterator
    /// ends because the sender was dropped.
    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so it does not need to be `mut`.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        );
    }

    /// A high-priority item sent while many low-priority items are queued is popped next.
    ///
    /// The outcome is deterministic because the receiver's RNG uses a fixed seed.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}