1use std::{
2 collections::VecDeque,
3 fmt,
4 iter::FusedIterator,
5 sync::{Arc, atomic::AtomicUsize},
6};
7
8use rand::{Rng, SeedableRng, rngs::SmallRng};
9
10use crate::Priority;
11
/// One FIFO queue per priority level that can be queued.
struct PriorityQueues<T> {
    /// Items sent with `Priority::High`.
    high_priority: VecDeque<T>,
    /// Items sent with `Priority::Medium`.
    medium_priority: VecDeque<T>,
    /// Items sent with `Priority::Low`.
    low_priority: VecDeque<T>,
}

impl<T> PriorityQueues<T> {
    /// Returns `true` when none of the three queues holds an item.
    fn is_empty(&self) -> bool {
        [
            &self.high_priority,
            &self.medium_priority,
            &self.low_priority,
        ]
        .iter()
        .all(|queue| queue.is_empty())
    }
}
25
26struct PriorityQueueState<T> {
27 queues: parking_lot::Mutex<PriorityQueues<T>>,
28 condvar: parking_lot::Condvar,
29 receiver_count: AtomicUsize,
30 sender_count: AtomicUsize,
31}
32
33impl<T> PriorityQueueState<T> {
34 fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
35 if self
36 .receiver_count
37 .load(std::sync::atomic::Ordering::Relaxed)
38 == 0
39 {
40 return Err(SendError(item));
41 }
42
43 let mut queues = self.queues.lock();
44 match priority {
45 Priority::RealtimeAudio => unreachable!(
46 "Realtime audio priority runs on a dedicated thread and is never queued"
47 ),
48 Priority::High => queues.high_priority.push_back(item),
49 Priority::Medium => queues.medium_priority.push_back(item),
50 Priority::Low => queues.low_priority.push_back(item),
51 };
52 self.condvar.notify_one();
53 Ok(())
54 }
55
56 fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
57 let mut queues = self.queues.lock();
58
59 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
60 if queues.is_empty() && sender_count == 0 {
61 return Err(crate::queue::RecvError);
62 }
63
64 while queues.is_empty() {
65 self.condvar.wait(&mut queues);
66 }
67
68 Ok(queues)
69 }
70
71 fn try_recv<'a>(
72 &'a self,
73 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
74 let mut queues = self.queues.lock();
75
76 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
77 if queues.is_empty() && sender_count == 0 {
78 return Err(crate::queue::RecvError);
79 }
80
81 if queues.is_empty() {
82 Ok(None)
83 } else {
84 Ok(Some(queues))
85 }
86 }
87}
88
89#[doc(hidden)]
90pub struct PriorityQueueSender<T> {
91 state: Arc<PriorityQueueState<T>>,
92}
93
94impl<T> PriorityQueueSender<T> {
95 fn new(state: Arc<PriorityQueueState<T>>) -> Self {
96 Self { state }
97 }
98
99 pub fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
100 self.state.send(priority, item)?;
101 Ok(())
102 }
103}
104
105impl<T> Drop for PriorityQueueSender<T> {
106 fn drop(&mut self) {
107 self.state
108 .sender_count
109 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
110 }
111}
112
113#[doc(hidden)]
114pub struct PriorityQueueReceiver<T> {
115 state: Arc<PriorityQueueState<T>>,
116 rand: SmallRng,
117 disconnected: bool,
118}
119
120impl<T> Clone for PriorityQueueReceiver<T> {
121 fn clone(&self) -> Self {
122 self.state
123 .receiver_count
124 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
125 Self {
126 state: Arc::clone(&self.state),
127 rand: SmallRng::seed_from_u64(0),
128 disconnected: self.disconnected,
129 }
130 }
131}
132
/// Error returned by a failed send; carries the item that could not be queued.
#[doc(hidden)]
#[derive(Debug)]
pub struct SendError<T>(pub T);

/// Error returned by a failed receive.
#[derive(Debug)]
#[doc(hidden)]
pub struct RecvError;
145
146#[allow(dead_code)]
147impl<T> PriorityQueueReceiver<T> {
148 pub fn new() -> (PriorityQueueSender<T>, Self) {
149 let state = PriorityQueueState {
150 queues: parking_lot::Mutex::new(PriorityQueues {
151 high_priority: VecDeque::new(),
152 medium_priority: VecDeque::new(),
153 low_priority: VecDeque::new(),
154 }),
155 condvar: parking_lot::Condvar::new(),
156 receiver_count: AtomicUsize::new(1),
157 sender_count: AtomicUsize::new(1),
158 };
159 let state = Arc::new(state);
160
161 let sender = PriorityQueueSender::new(Arc::clone(&state));
162
163 let receiver = PriorityQueueReceiver {
164 state,
165 rand: SmallRng::seed_from_u64(0),
166 disconnected: false,
167 };
168
169 (sender, receiver)
170 }
171
172 /// Tries to pop one element from the priority queue without blocking.
173 ///
174 /// This will early return if there are no elements in the queue.
175 ///
176 /// This method is best suited if you only intend to pop one element, for better performance
177 /// on large queues see [`Self::try_iter`]
178 ///
179 /// # Errors
180 ///
181 /// If the sender was dropped
182 pub fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
183 self.pop_inner(false)
184 }
185
186 /// Pops an element from the priority queue blocking if necessary.
187 ///
188 /// This method is best suited if you only intend to pop one element, for better performance
189 /// on large queues see [`Self::iter``]
190 ///
191 /// # Errors
192 ///
193 /// If the sender was dropped
194 pub fn pop(&mut self) -> Result<T, RecvError> {
195 self.pop_inner(true).map(|e| e.unwrap())
196 }
197
198 /// Returns an iterator over the elements of the queue
199 /// this iterator will end when all elements have been consumed and will not wait for new ones.
200 pub fn try_iter(self) -> TryIter<T> {
201 TryIter {
202 receiver: self,
203 ended: false,
204 }
205 }
206
207 /// Returns an iterator over the elements of the queue
208 /// this iterator will wait for new elements if the queue is empty.
209 pub fn iter(self) -> Iter<T> {
210 Iter(self)
211 }
212
213 #[inline(always)]
214 // algorithm is the loaded die from biased coin from
215 // https://www.keithschwarz.com/darts-dice-coins/
216 fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
217 use Priority as P;
218
219 let mut queues = if !block {
220 let Some(queues) = self.state.try_recv()? else {
221 return Ok(None);
222 };
223 queues
224 } else {
225 self.state.recv()?
226 };
227
228 let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
229 let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
230 let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
231 let mut mass = high + medium + low; //%
232
233 if !queues.high_priority.is_empty() {
234 let flip = self.rand.random_ratio(P::High.weight(), mass);
235 if flip {
236 return Ok(queues.high_priority.pop_front());
237 }
238 mass -= P::High.weight();
239 }
240
241 if !queues.medium_priority.is_empty() {
242 let flip = self.rand.random_ratio(P::Medium.weight(), mass);
243 if flip {
244 return Ok(queues.medium_priority.pop_front());
245 }
246 mass -= P::Medium.weight();
247 }
248
249 if !queues.low_priority.is_empty() {
250 let flip = self.rand.random_ratio(P::Low.weight(), mass);
251 if flip {
252 return Ok(queues.low_priority.pop_front());
253 }
254 }
255
256 Ok(None)
257 }
258}
259
260impl<T> Drop for PriorityQueueReceiver<T> {
261 fn drop(&mut self) {
262 self.state
263 .receiver_count
264 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
265 }
266}
267
268#[doc(hidden)]
269pub struct Iter<T>(PriorityQueueReceiver<T>);
270impl<T> Iterator for Iter<T> {
271 type Item = T;
272
273 fn next(&mut self) -> Option<Self::Item> {
274 self.0.pop().ok()
275 }
276}
277impl<T> FusedIterator for Iter<T> {}
278
279#[doc(hidden)]
280pub struct TryIter<T> {
281 receiver: PriorityQueueReceiver<T>,
282 ended: bool,
283}
284impl<T> Iterator for TryIter<T> {
285 type Item = Result<T, RecvError>;
286
287 fn next(&mut self) -> Option<Self::Item> {
288 if self.ended {
289 return None;
290 }
291
292 let res = self.receiver.try_pop();
293 self.ended = res.is_err();
294
295 res.transpose()
296 }
297}
298impl<T> FusedIterator for TryIter<T> {}
299
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Everything sent before the sender is dropped comes out of `iter`,
    /// whatever order the priorities are served in.
    #[test]
    fn all_tasks_get_yielded() {
        let (tx, rx) = PriorityQueueReceiver::new();
        let sent = [
            (Priority::Medium, 20),
            (Priority::High, 30),
            (Priority::Low, 10),
            (Priority::Medium, 21),
            (Priority::High, 31),
        ];
        for (priority, value) in sent {
            tx.send(priority, value).unwrap();
        }

        drop(tx);

        let received = rx.iter().collect::<HashSet<_>>();
        let expected = [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>();
        assert_eq!(received, expected)
    }

    /// A high-priority item sent behind a backlog of low-priority items is
    /// popped next, ahead of the remaining backlog.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        (0..100).for_each(|_| tx.send(Priority::Low, 1).unwrap());

        let first = rx.pop().unwrap();
        assert_eq!(first, 1);

        tx.send(Priority::High, 3).unwrap();

        let second = rx.pop().unwrap();
        assert_eq!(second, 3);
        let third = rx.pop().unwrap();
        assert_eq!(third, 1);
    }
}