1use std::{
2 collections::VecDeque,
3 fmt,
4 iter::FusedIterator,
5 sync::{Arc, atomic::AtomicUsize},
6};
7
8use rand::{Rng, SeedableRng, rngs::SmallRng};
9
10use crate::Priority;
11
/// Backing storage of the channel: one FIFO lane per queueable priority.
struct PriorityQueues<T> {
    /// Items queued with `Priority::High`.
    high_priority: VecDeque<T>,
    /// Items queued with `Priority::Medium`.
    medium_priority: VecDeque<T>,
    /// Items queued with `Priority::Low`.
    low_priority: VecDeque<T>,
}

impl<T> PriorityQueues<T> {
    /// Returns `true` only when none of the three lanes holds an item.
    fn is_empty(&self) -> bool {
        let lanes = [
            &self.high_priority,
            &self.medium_priority,
            &self.low_priority,
        ];
        lanes.iter().all(|lane| lane.is_empty())
    }
}
25
26struct PriorityQueueState<T> {
27 queues: parking_lot::Mutex<PriorityQueues<T>>,
28 condvar: parking_lot::Condvar,
29 receiver_count: AtomicUsize,
30 sender_count: AtomicUsize,
31}
32
33impl<T> PriorityQueueState<T> {
34 fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
35 if self
36 .receiver_count
37 .load(std::sync::atomic::Ordering::Relaxed)
38 == 0
39 {
40 return Err(SendError(item));
41 }
42
43 let mut queues = self.queues.lock();
44 Self::push(&mut queues, priority, item);
45 self.condvar.notify_one();
46 Ok(())
47 }
48
49 fn spin_send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
50 if self
51 .receiver_count
52 .load(std::sync::atomic::Ordering::Relaxed)
53 == 0
54 {
55 return Err(SendError(item));
56 }
57
58 let mut queues = loop {
59 if let Some(guard) = self.queues.try_lock() {
60 break guard;
61 }
62 std::hint::spin_loop();
63 };
64 Self::push(&mut queues, priority, item);
65 self.condvar.notify_one();
66 Ok(())
67 }
68
69 fn push(queues: &mut PriorityQueues<T>, priority: Priority, item: T) {
70 match priority {
71 Priority::RealtimeAudio => unreachable!(
72 "Realtime audio priority runs on a dedicated thread and is never queued"
73 ),
74 Priority::High => queues.high_priority.push_back(item),
75 Priority::Medium => queues.medium_priority.push_back(item),
76 Priority::Low => queues.low_priority.push_back(item),
77 };
78 }
79
80 fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
81 let mut queues = self.queues.lock();
82
83 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
84 if queues.is_empty() && sender_count == 0 {
85 return Err(crate::queue::RecvError);
86 }
87
88 while queues.is_empty() {
89 self.condvar.wait(&mut queues);
90 }
91
92 Ok(queues)
93 }
94
95 fn try_recv<'a>(
96 &'a self,
97 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
98 let mut queues = self.queues.lock();
99
100 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
101 if queues.is_empty() && sender_count == 0 {
102 return Err(crate::queue::RecvError);
103 }
104
105 if queues.is_empty() {
106 Ok(None)
107 } else {
108 Ok(Some(queues))
109 }
110 }
111
112 fn spin_try_recv<'a>(
113 &'a self,
114 ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
115 let queues = loop {
116 if let Some(guard) = self.queues.try_lock() {
117 break guard;
118 }
119 std::hint::spin_loop();
120 };
121
122 let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
123 if queues.is_empty() && sender_count == 0 {
124 return Err(crate::queue::RecvError);
125 }
126
127 if queues.is_empty() {
128 Ok(None)
129 } else {
130 Ok(Some(queues))
131 }
132 }
133}
134
135#[doc(hidden)]
136pub struct PriorityQueueSender<T> {
137 state: Arc<PriorityQueueState<T>>,
138}
139
140impl<T> PriorityQueueSender<T> {
141 fn new(state: Arc<PriorityQueueState<T>>) -> Self {
142 Self { state }
143 }
144
145 pub fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
146 self.state.send(priority, item)?;
147 Ok(())
148 }
149
150 pub fn spin_send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
151 self.state.spin_send(priority, item)?;
152 Ok(())
153 }
154}
155
156impl<T> Drop for PriorityQueueSender<T> {
157 fn drop(&mut self) {
158 self.state
159 .sender_count
160 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
161 }
162}
163
164#[doc(hidden)]
165pub struct PriorityQueueReceiver<T> {
166 state: Arc<PriorityQueueState<T>>,
167 rand: SmallRng,
168 disconnected: bool,
169}
170
171impl<T> Clone for PriorityQueueReceiver<T> {
172 fn clone(&self) -> Self {
173 self.state
174 .receiver_count
175 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
176 Self {
177 state: Arc::clone(&self.state),
178 rand: SmallRng::seed_from_u64(0),
179 disconnected: self.disconnected,
180 }
181 }
182}
183
/// Error returned by a failed send; it carries the item that could not be queued.
#[doc(hidden)]
pub struct SendError<T>(pub T);

impl<T> fmt::Debug for SendError<T>
where
    T: fmt::Debug,
{
    /// Formats the error as a tuple struct named `SendError` wrapping the unsent item.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(unsent) = self;
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(unsent);
        tuple.finish()
    }
}

/// Error returned by a failed receive.
#[derive(Debug)]
#[doc(hidden)]
pub struct RecvError;
196
197#[allow(dead_code)]
198impl<T> PriorityQueueReceiver<T> {
199 pub fn new() -> (PriorityQueueSender<T>, Self) {
200 let state = PriorityQueueState {
201 queues: parking_lot::Mutex::new(PriorityQueues {
202 high_priority: VecDeque::new(),
203 medium_priority: VecDeque::new(),
204 low_priority: VecDeque::new(),
205 }),
206 condvar: parking_lot::Condvar::new(),
207 receiver_count: AtomicUsize::new(1),
208 sender_count: AtomicUsize::new(1),
209 };
210 let state = Arc::new(state);
211
212 let sender = PriorityQueueSender::new(Arc::clone(&state));
213
214 let receiver = PriorityQueueReceiver {
215 state,
216 rand: SmallRng::seed_from_u64(0),
217 disconnected: false,
218 };
219
220 (sender, receiver)
221 }
222
223 /// Tries to pop one element from the priority queue without blocking.
224 ///
225 /// This will early return if there are no elements in the queue.
226 ///
227 /// This method is best suited if you only intend to pop one element, for better performance
228 /// on large queues see [`Self::try_iter`]
229 ///
230 /// # Errors
231 ///
232 /// If the sender was dropped
233 pub fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
234 self.pop_inner(false)
235 }
236
237 pub fn spin_try_pop(&mut self) -> Result<Option<T>, RecvError> {
238 use Priority as P;
239
240 let Some(mut queues) = self.state.spin_try_recv()? else {
241 return Ok(None);
242 };
243
244 let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
245 let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
246 let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
247 let mut mass = high + medium + low;
248
249 if !queues.high_priority.is_empty() {
250 let flip = self.rand.random_ratio(P::High.weight(), mass);
251 if flip {
252 return Ok(queues.high_priority.pop_front());
253 }
254 mass -= P::High.weight();
255 }
256
257 if !queues.medium_priority.is_empty() {
258 let flip = self.rand.random_ratio(P::Medium.weight(), mass);
259 if flip {
260 return Ok(queues.medium_priority.pop_front());
261 }
262 mass -= P::Medium.weight();
263 }
264
265 if !queues.low_priority.is_empty() {
266 let flip = self.rand.random_ratio(P::Low.weight(), mass);
267 if flip {
268 return Ok(queues.low_priority.pop_front());
269 }
270 }
271
272 Ok(None)
273 }
274
275 /// Pops an element from the priority queue blocking if necessary.
276 ///
277 /// This method is best suited if you only intend to pop one element, for better performance
278 /// on large queues see [`Self::iter``]
279 ///
280 /// # Errors
281 ///
282 /// If the sender was dropped
283 pub fn pop(&mut self) -> Result<T, RecvError> {
284 self.pop_inner(true).map(|e| e.unwrap())
285 }
286
287 /// Returns an iterator over the elements of the queue
288 /// this iterator will end when all elements have been consumed and will not wait for new ones.
289 pub fn try_iter(self) -> TryIter<T> {
290 TryIter {
291 receiver: self,
292 ended: false,
293 }
294 }
295
296 /// Returns an iterator over the elements of the queue
297 /// this iterator will wait for new elements if the queue is empty.
298 pub fn iter(self) -> Iter<T> {
299 Iter(self)
300 }
301
302 #[inline(always)]
303 // algorithm is the loaded die from biased coin from
304 // https://www.keithschwarz.com/darts-dice-coins/
305 fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
306 use Priority as P;
307
308 let mut queues = if !block {
309 let Some(queues) = self.state.try_recv()? else {
310 return Ok(None);
311 };
312 queues
313 } else {
314 self.state.recv()?
315 };
316
317 let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
318 let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
319 let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
320 let mut mass = high + medium + low; //%
321
322 if !queues.high_priority.is_empty() {
323 let flip = self.rand.random_ratio(P::High.weight(), mass);
324 if flip {
325 return Ok(queues.high_priority.pop_front());
326 }
327 mass -= P::High.weight();
328 }
329
330 if !queues.medium_priority.is_empty() {
331 let flip = self.rand.random_ratio(P::Medium.weight(), mass);
332 if flip {
333 return Ok(queues.medium_priority.pop_front());
334 }
335 mass -= P::Medium.weight();
336 }
337
338 if !queues.low_priority.is_empty() {
339 let flip = self.rand.random_ratio(P::Low.weight(), mass);
340 if flip {
341 return Ok(queues.low_priority.pop_front());
342 }
343 }
344
345 Ok(None)
346 }
347}
348
349impl<T> Drop for PriorityQueueReceiver<T> {
350 fn drop(&mut self) {
351 self.state
352 .receiver_count
353 .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
354 }
355}
356
357#[doc(hidden)]
358pub struct Iter<T>(PriorityQueueReceiver<T>);
359impl<T> Iterator for Iter<T> {
360 type Item = T;
361
362 fn next(&mut self) -> Option<Self::Item> {
363 self.0.pop().ok()
364 }
365}
366impl<T> FusedIterator for Iter<T> {}
367
368#[doc(hidden)]
369pub struct TryIter<T> {
370 receiver: PriorityQueueReceiver<T>,
371 ended: bool,
372}
373impl<T> Iterator for TryIter<T> {
374 type Item = Result<T, RecvError>;
375
376 fn next(&mut self) -> Option<Self::Item> {
377 if self.ended {
378 return None;
379 }
380
381 let res = self.receiver.try_pop();
382 self.ended = res.is_err();
383
384 res.transpose()
385 }
386}
387impl<T> FusedIterator for TryIter<T> {}
388
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every queued item is yielded once the sender is gone, whatever its priority.
    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so the binding does not need to be mutable.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        // Dropping the sender lets the blocking iterator end once the queue is drained.
        drop(tx);

        // Pop order is randomised, so only the set of yielded items is compared.
        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    /// A high priority item sent behind a backlog of low priority items is popped next.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        // NOTE(review): the lane choice is random but the receiver is seeded with 0, so the
        // outcome is reproducible; presumably this relies on the priority weights — confirm.
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}