1use std::{
2 collections::VecDeque,
3 fmt,
4 iter::FusedIterator,
5 sync::{Arc, atomic::AtomicUsize},
6};
7
8use rand::{Rng, SeedableRng, rngs::SmallRng};
9
10use crate::Priority;
11
/// The FIFO buckets that back a priority queue, one per queued priority level.
struct PriorityQueues<T> {
    high_priority: VecDeque<T>,
    medium_priority: VecDeque<T>,
    low_priority: VecDeque<T>,
}

impl<T> PriorityQueues<T> {
    /// Returns `true` when none of the buckets holds an element.
    fn is_empty(&self) -> bool {
        [
            &self.high_priority,
            &self.medium_priority,
            &self.low_priority,
        ]
        .into_iter()
        .all(VecDeque::is_empty)
    }
}
25
26struct PriorityQueueState<T> {
27 queues: parking_lot::Mutex<PriorityQueues<T>>,
28 condvar: parking_lot::Condvar,
29 receiver_count: AtomicUsize,
30 sender_count: AtomicUsize,
31}
32
33impl<T> PriorityQueueState<T> {
34 fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
35 if self
36 .receiver_count
37 .load(std::sync::atomic::Ordering::Relaxed)
38 == 0
39 {
40 return Err(SendError(item));
41 }
42
43 let mut queues = self.queues.lock();
44 match priority {
45 Priority::RealtimeAudio => unreachable!(
46 "Realtime audio priority runs on a dedicated thread and is never queued"
47 ),
48 Priority::High => queues.high_priority.push_back(item),
49 Priority::Medium => queues.medium_priority.push_back(item),
50 Priority::Low => queues.low_priority.push_back(item),
51 };
52 self.condvar.notify_one();
53 Ok(())
54 }
55
56 fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
57 let mut queues = self.queues.lock();
58
59 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
60 if queues.is_empty() && sender_count == 0 {
61 return Err(crate::queue::RecvError);
62 }
63
64 while queues.is_empty() {
65 self.condvar.wait(&mut queues);
66 }
67
68 Ok(queues)
69 }
70
71 fn try_recv<'a>(
72 &'a self,
73 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
74 let mut queues = self.queues.lock();
75
76 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
77 if queues.is_empty() && sender_count == 0 {
78 return Err(crate::queue::RecvError);
79 }
80
81 if queues.is_empty() {
82 Ok(None)
83 } else {
84 Ok(Some(queues))
85 }
86 }
87}
88
89pub(crate) struct PriorityQueueSender<T> {
90 state: Arc<PriorityQueueState<T>>,
91}
92
93impl<T> PriorityQueueSender<T> {
94 fn new(state: Arc<PriorityQueueState<T>>) -> Self {
95 Self { state }
96 }
97
98 pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
99 self.state.send(priority, item)?;
100 Ok(())
101 }
102}
103
104impl<T> Drop for PriorityQueueSender<T> {
105 fn drop(&mut self) {
106 self.state
107 .sender_count
108 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
109 }
110}
111
112pub(crate) struct PriorityQueueReceiver<T> {
113 state: Arc<PriorityQueueState<T>>,
114 rand: SmallRng,
115 disconnected: bool,
116}
117
118impl<T> Clone for PriorityQueueReceiver<T> {
119 fn clone(&self) -> Self {
120 self.state
121 .receiver_count
122 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
123 Self {
124 state: Arc::clone(&self.state),
125 rand: SmallRng::seed_from_u64(0),
126 disconnected: self.disconnected,
127 }
128 }
129}
130
/// Error returned by `send` when no receiver is left; it holds the rejected item.
pub(crate) struct SendError<T>(T);

impl<T> fmt::Debug for SendError<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let SendError(item) = self;
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(item);
        tuple.finish()
    }
}
138
/// Error returned when receiving from a queue that is empty and whose sender has been dropped.
#[derive(Debug)]
pub(crate) struct RecvError;
141
142#[allow(dead_code)]
143impl<T> PriorityQueueReceiver<T> {
144 pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
145 let state = PriorityQueueState {
146 queues: parking_lot::Mutex::new(PriorityQueues {
147 high_priority: VecDeque::new(),
148 medium_priority: VecDeque::new(),
149 low_priority: VecDeque::new(),
150 }),
151 condvar: parking_lot::Condvar::new(),
152 receiver_count: AtomicUsize::new(1),
153 sender_count: AtomicUsize::new(1),
154 };
155 let state = Arc::new(state);
156
157 let sender = PriorityQueueSender::new(Arc::clone(&state));
158
159 let receiver = PriorityQueueReceiver {
160 state,
161 rand: SmallRng::seed_from_u64(0),
162 disconnected: false,
163 };
164
165 (sender, receiver)
166 }
167
168 /// Tries to pop one element from the priority queue without blocking.
169 ///
170 /// This will early return if there are no elements in the queue.
171 ///
172 /// This method is best suited if you only intend to pop one element, for better performance
173 /// on large queues see [`Self::try_iter`]
174 ///
175 /// # Errors
176 ///
177 /// If the sender was dropped
178 pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
179 self.pop_inner(false)
180 }
181
182 /// Pops an element from the priority queue blocking if necessary.
183 ///
184 /// This method is best suited if you only intend to pop one element, for better performance
185 /// on large queues see [`Self::iter``]
186 ///
187 /// # Errors
188 ///
189 /// If the sender was dropped
190 pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
191 self.pop_inner(true).map(|e| e.unwrap())
192 }
193
194 /// Returns an iterator over the elements of the queue
195 /// this iterator will end when all elements have been consumed and will not wait for new ones.
196 pub(crate) fn try_iter(self) -> TryIter<T> {
197 TryIter {
198 receiver: self,
199 ended: false,
200 }
201 }
202
203 /// Returns an iterator over the elements of the queue
204 /// this iterator will wait for new elements if the queue is empty.
205 pub(crate) fn iter(self) -> Iter<T> {
206 Iter(self)
207 }
208
209 #[inline(always)]
210 // algorithm is the loaded die from biased coin from
211 // https://www.keithschwarz.com/darts-dice-coins/
212 fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
213 use Priority as P;
214
215 let mut queues = if !block {
216 let Some(queues) = self.state.try_recv()? else {
217 return Ok(None);
218 };
219 queues
220 } else {
221 self.state.recv()?
222 };
223
224 let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
225 let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
226 let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
227 let mut mass = high + medium + low; //%
228
229 if !queues.high_priority.is_empty() {
230 let flip = self.rand.random_ratio(P::High.weight(), mass);
231 if flip {
232 return Ok(queues.high_priority.pop_front());
233 }
234 mass -= P::High.weight();
235 }
236
237 if !queues.medium_priority.is_empty() {
238 let flip = self.rand.random_ratio(P::Medium.weight(), mass);
239 if flip {
240 return Ok(queues.medium_priority.pop_front());
241 }
242 mass -= P::Medium.weight();
243 }
244
245 if !queues.low_priority.is_empty() {
246 let flip = self.rand.random_ratio(P::Low.weight(), mass);
247 if flip {
248 return Ok(queues.low_priority.pop_front());
249 }
250 }
251
252 Ok(None)
253 }
254}
255
256impl<T> Drop for PriorityQueueReceiver<T> {
257 fn drop(&mut self) {
258 self.state
259 .receiver_count
260 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
261 }
262}
263
264/// If None is returned the sender disconnected
265pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
266impl<T> Iterator for Iter<T> {
267 type Item = T;
268
269 fn next(&mut self) -> Option<Self::Item> {
270 self.0.pop().ok()
271 }
272}
273impl<T> FusedIterator for Iter<T> {}
274
275/// If None is returned there are no more elements in the queue
276pub(crate) struct TryIter<T> {
277 receiver: PriorityQueueReceiver<T>,
278 ended: bool,
279}
280impl<T> Iterator for TryIter<T> {
281 type Item = Result<T, RecvError>;
282
283 fn next(&mut self) -> Option<Self::Item> {
284 if self.ended {
285 return None;
286 }
287
288 let res = self.receiver.try_pop();
289 self.ended = res.is_err();
290
291 res.transpose()
292 }
293}
294impl<T> FusedIterator for TryIter<T> {}
295
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every queued element is eventually yielded, regardless of its priority.
    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so the binding does not need to be mutable.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        // Dropping the only sender lets `iter` finish once the queue has been drained.
        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    /// A high-priority element sent behind a long run of low-priority ones is picked next.
    ///
    /// The pick is randomized; this test relies on the receiver's deterministic RNG seed.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}