1use std::{
2 collections::VecDeque,
3 fmt,
4 iter::FusedIterator,
5 sync::{Arc, atomic::AtomicUsize},
6};
7
8use rand::{Rng, SeedableRng, rngs::SmallRng};
9
10use crate::Priority;
11
/// One FIFO queue per schedulable priority level.
struct PriorityQueues<T> {
    /// Items sent with `Priority::High`.
    high_priority: VecDeque<T>,
    /// Items sent with `Priority::Medium`.
    medium_priority: VecDeque<T>,
    /// Items sent with `Priority::Low`.
    low_priority: VecDeque<T>,
}

impl<T> PriorityQueues<T> {
    /// Returns `true` when none of the three queues holds an element.
    fn is_empty(&self) -> bool {
        [&self.high_priority, &self.medium_priority, &self.low_priority]
            .iter()
            .all(|queue| queue.is_empty())
    }
}
25
26struct PriorityQueueState<T> {
27 queues: parking_lot::Mutex<PriorityQueues<T>>,
28 condvar: parking_lot::Condvar,
29 receiver_count: AtomicUsize,
30 sender_count: AtomicUsize,
31}
32
33impl<T> PriorityQueueState<T> {
34 fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
35 if self
36 .receiver_count
37 .load(std::sync::atomic::Ordering::Relaxed)
38 == 0
39 {
40 return Err(SendError(item));
41 }
42
43 let mut queues = self.queues.lock();
44 match priority {
45 Priority::Realtime(_) => unreachable!(),
46 Priority::High => queues.high_priority.push_back(item),
47 Priority::Medium => queues.medium_priority.push_back(item),
48 Priority::Low => queues.low_priority.push_back(item),
49 };
50 self.condvar.notify_one();
51 Ok(())
52 }
53
54 fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
55 let mut queues = self.queues.lock();
56
57 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
58 if queues.is_empty() && sender_count == 0 {
59 return Err(crate::queue::RecvError);
60 }
61
62 while queues.is_empty() {
63 self.condvar.wait(&mut queues);
64 }
65
66 Ok(queues)
67 }
68
69 fn try_recv<'a>(
70 &'a self,
71 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
72 let mut queues = self.queues.lock();
73
74 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
75 if queues.is_empty() && sender_count == 0 {
76 return Err(crate::queue::RecvError);
77 }
78
79 if queues.is_empty() {
80 Ok(None)
81 } else {
82 Ok(Some(queues))
83 }
84 }
85}
86
87pub(crate) struct PriorityQueueSender<T> {
88 state: Arc<PriorityQueueState<T>>,
89}
90
91impl<T> PriorityQueueSender<T> {
92 fn new(state: Arc<PriorityQueueState<T>>) -> Self {
93 Self { state }
94 }
95
96 pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
97 self.state.send(priority, item)?;
98 Ok(())
99 }
100}
101
102impl<T> Drop for PriorityQueueSender<T> {
103 fn drop(&mut self) {
104 self.state
105 .sender_count
106 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
107 }
108}
109
110pub(crate) struct PriorityQueueReceiver<T> {
111 state: Arc<PriorityQueueState<T>>,
112 rand: SmallRng,
113 disconnected: bool,
114}
115
116impl<T> Clone for PriorityQueueReceiver<T> {
117 fn clone(&self) -> Self {
118 self.state
119 .receiver_count
120 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
121 Self {
122 state: Arc::clone(&self.state),
123 rand: SmallRng::seed_from_u64(0),
124 disconnected: self.disconnected,
125 }
126 }
127}
128
/// Error returned when sending on a priority queue whose receivers have all been dropped.
///
/// Carries the item that could not be delivered.
pub(crate) struct SendError<T>(T);

impl<T: fmt::Debug> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("SendError").field(&self.0).finish()
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("sending on a closed priority queue")
    }
}

// `Error` requires `Debug`, which is only implemented when the payload is `Debug`.
impl<T: fmt::Debug> std::error::Error for SendError<T> {}
136
/// Error returned by receive operations once every sender has been dropped and the queue
/// is empty.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct RecvError;

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving on a closed priority queue")
    }
}

impl std::error::Error for RecvError {}
139
140#[allow(dead_code)]
141impl<T> PriorityQueueReceiver<T> {
142 pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
143 let state = PriorityQueueState {
144 queues: parking_lot::Mutex::new(PriorityQueues {
145 high_priority: VecDeque::new(),
146 medium_priority: VecDeque::new(),
147 low_priority: VecDeque::new(),
148 }),
149 condvar: parking_lot::Condvar::new(),
150 receiver_count: AtomicUsize::new(1),
151 sender_count: AtomicUsize::new(1),
152 };
153 let state = Arc::new(state);
154
155 let sender = PriorityQueueSender::new(Arc::clone(&state));
156
157 let receiver = PriorityQueueReceiver {
158 state,
159 rand: SmallRng::seed_from_u64(0),
160 disconnected: false,
161 };
162
163 (sender, receiver)
164 }
165
166 /// Tries to pop one element from the priority queue without blocking.
167 ///
168 /// This will early return if there are no elements in the queue.
169 ///
170 /// This method is best suited if you only intend to pop one element, for better performance
171 /// on large queues see [`Self::try_iter`]
172 ///
173 /// # Errors
174 ///
175 /// If the sender was dropped
176 pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
177 self.pop_inner(false)
178 }
179
180 /// Pops an element from the priority queue blocking if necessary.
181 ///
182 /// This method is best suited if you only intend to pop one element, for better performance
183 /// on large queues see [`Self::iter``]
184 ///
185 /// # Errors
186 ///
187 /// If the sender was dropped
188 pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
189 self.pop_inner(true).map(|e| e.unwrap())
190 }
191
192 /// Returns an iterator over the elements of the queue
193 /// this iterator will end when all elements have been consumed and will not wait for new ones.
194 pub(crate) fn try_iter(self) -> TryIter<T> {
195 TryIter {
196 receiver: self,
197 ended: false,
198 }
199 }
200
201 /// Returns an iterator over the elements of the queue
202 /// this iterator will wait for new elements if the queue is empty.
203 pub(crate) fn iter(self) -> Iter<T> {
204 Iter(self)
205 }
206
207 #[inline(always)]
208 // algorithm is the loaded die from biased coin from
209 // https://www.keithschwarz.com/darts-dice-coins/
210 fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
211 use Priority as P;
212
213 let mut queues = if !block {
214 let Some(queues) = self.state.try_recv()? else {
215 return Ok(None);
216 };
217 queues
218 } else {
219 self.state.recv()?
220 };
221
222 let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
223 let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
224 let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
225 let mut mass = high + medium + low; //%
226
227 if !queues.high_priority.is_empty() {
228 let flip = self.rand.random_ratio(P::High.probability(), mass);
229 if flip {
230 return Ok(queues.high_priority.pop_front());
231 }
232 mass -= P::High.probability();
233 }
234
235 if !queues.medium_priority.is_empty() {
236 let flip = self.rand.random_ratio(P::Medium.probability(), mass);
237 if flip {
238 return Ok(queues.medium_priority.pop_front());
239 }
240 mass -= P::Medium.probability();
241 }
242
243 if !queues.low_priority.is_empty() {
244 let flip = self.rand.random_ratio(P::Low.probability(), mass);
245 if flip {
246 return Ok(queues.low_priority.pop_front());
247 }
248 }
249
250 Ok(None)
251 }
252}
253
254impl<T> Drop for PriorityQueueReceiver<T> {
255 fn drop(&mut self) {
256 self.state
257 .receiver_count
258 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
259 }
260}
261
262/// If None is returned the sender disconnected
263pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
264impl<T> Iterator for Iter<T> {
265 type Item = T;
266
267 fn next(&mut self) -> Option<Self::Item> {
268 self.0.pop().ok()
269 }
270}
271impl<T> FusedIterator for Iter<T> {}
272
273/// If None is returned there are no more elements in the queue
274pub(crate) struct TryIter<T> {
275 receiver: PriorityQueueReceiver<T>,
276 ended: bool,
277}
278impl<T> Iterator for TryIter<T> {
279 type Item = Result<T, RecvError>;
280
281 fn next(&mut self) -> Option<Self::Item> {
282 if self.ended {
283 return None;
284 }
285
286 let res = self.receiver.try_pop();
287 self.ended = res.is_err();
288
289 res.transpose()
290 }
291}
292impl<T> FusedIterator for TryIter<T> {}
293
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    #[test]
    fn all_tasks_get_yielded() {
        let (tx, rx) = PriorityQueueReceiver::new();
        let items = [
            (Priority::Medium, 20),
            (Priority::High, 30),
            (Priority::Low, 10),
            (Priority::Medium, 21),
            (Priority::High, 31),
        ];
        for (priority, item) in items {
            tx.send(priority, item).unwrap();
        }

        // Dropping the only sender lets the blocking iterator finish once it is drained.
        drop(tx);

        let received: HashSet<_> = rx.iter().collect();
        let expected: HashSet<_> = [30, 31, 20, 21, 10].into_iter().collect();
        assert_eq!(received, expected)
    }

    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        (0..100).for_each(|_| tx.send(Priority::Low, 1).unwrap());

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        // With the fixed RNG seed, the freshly queued high-priority item wins the next draw.
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}