queue.rs

  1use std::{
  2    collections::VecDeque,
  3    fmt,
  4    iter::FusedIterator,
  5    sync::{Arc, atomic::AtomicUsize},
  6};
  7
  8use rand::{Rng, SeedableRng, rngs::SmallRng};
  9
 10use crate::Priority;
 11
 12struct PriorityQueues<T> {
 13    high_priority: VecDeque<T>,
 14    medium_priority: VecDeque<T>,
 15    low_priority: VecDeque<T>,
 16}
 17
 18impl<T> PriorityQueues<T> {
 19    fn is_empty(&self) -> bool {
 20        self.high_priority.is_empty()
 21            && self.medium_priority.is_empty()
 22            && self.low_priority.is_empty()
 23    }
 24}
 25
/// State shared (through an `Arc`) between the sending and receiving handles of a queue.
struct PriorityQueueState<T> {
    /// The per-priority lanes, guarded by a mutex.
    queues: parking_lot::Mutex<PriorityQueues<T>>,
    /// Signalled after an item is pushed; blocking receivers wait on it while the lanes are empty.
    condvar: parking_lot::Condvar,
    /// Number of live receiver handles; sending fails once it reaches zero.
    receiver_count: AtomicUsize,
    /// Number of live sender handles; receiving from an empty queue fails once it reaches zero.
    sender_count: AtomicUsize,
}
 32
 33impl<T> PriorityQueueState<T> {
 34    fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 35        if self
 36            .receiver_count
 37            .load(std::sync::atomic::Ordering::Relaxed)
 38            == 0
 39        {
 40            return Err(SendError(item));
 41        }
 42
 43        let mut queues = self.queues.lock();
 44        Self::push(&mut queues, priority, item);
 45        self.condvar.notify_one();
 46        Ok(())
 47    }
 48
 49    fn spin_send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 50        if self
 51            .receiver_count
 52            .load(std::sync::atomic::Ordering::Relaxed)
 53            == 0
 54        {
 55            return Err(SendError(item));
 56        }
 57
 58        let mut queues = loop {
 59            if let Some(guard) = self.queues.try_lock() {
 60                break guard;
 61            }
 62            std::hint::spin_loop();
 63        };
 64        Self::push(&mut queues, priority, item);
 65        self.condvar.notify_one();
 66        Ok(())
 67    }
 68
 69    fn push(queues: &mut PriorityQueues<T>, priority: Priority, item: T) {
 70        match priority {
 71            Priority::RealtimeAudio => unreachable!(
 72                "Realtime audio priority runs on a dedicated thread and is never queued"
 73            ),
 74            Priority::High => queues.high_priority.push_back(item),
 75            Priority::Medium => queues.medium_priority.push_back(item),
 76            Priority::Low => queues.low_priority.push_back(item),
 77        };
 78    }
 79
 80    fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
 81        let mut queues = self.queues.lock();
 82
 83        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 84        if queues.is_empty() && sender_count == 0 {
 85            return Err(crate::queue::RecvError);
 86        }
 87
 88        while queues.is_empty() {
 89            self.condvar.wait(&mut queues);
 90        }
 91
 92        Ok(queues)
 93    }
 94
 95    fn try_recv<'a>(
 96        &'a self,
 97    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
 98        let mut queues = self.queues.lock();
 99
100        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
101        if queues.is_empty() && sender_count == 0 {
102            return Err(crate::queue::RecvError);
103        }
104
105        if queues.is_empty() {
106            Ok(None)
107        } else {
108            Ok(Some(queues))
109        }
110    }
111
112    fn spin_try_recv<'a>(
113        &'a self,
114    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
115        let queues = loop {
116            if let Some(guard) = self.queues.try_lock() {
117                break guard;
118            }
119            std::hint::spin_loop();
120        };
121
122        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
123        if queues.is_empty() && sender_count == 0 {
124            return Err(crate::queue::RecvError);
125        }
126
127        if queues.is_empty() {
128            Ok(None)
129        } else {
130            Ok(Some(queues))
131        }
132    }
133}
134
135#[doc(hidden)]
136pub struct PriorityQueueSender<T> {
137    state: Arc<PriorityQueueState<T>>,
138}
139
140impl<T> PriorityQueueSender<T> {
141    fn new(state: Arc<PriorityQueueState<T>>) -> Self {
142        Self { state }
143    }
144
145    pub fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
146        self.state.send(priority, item)?;
147        Ok(())
148    }
149
150    pub fn spin_send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
151        self.state.spin_send(priority, item)?;
152        Ok(())
153    }
154}
155
156impl<T> Drop for PriorityQueueSender<T> {
157    fn drop(&mut self) {
158        self.state
159            .sender_count
160            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
161    }
162}
163
/// Receiving half of the priority queue.
#[doc(hidden)]
pub struct PriorityQueueReceiver<T> {
    /// State shared with the sender and with every clone of this receiver.
    state: Arc<PriorityQueueState<T>>,
    /// Random source used to pick which non-empty lane to pop from.
    rand: SmallRng,
    // NOTE(review): initialised to `false` and copied on clone, but never read or updated in
    // this file — confirm whether it is still needed.
    disconnected: bool,
}
170
171impl<T> Clone for PriorityQueueReceiver<T> {
172    fn clone(&self) -> Self {
173        self.state
174            .receiver_count
175            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
176        Self {
177            state: Arc::clone(&self.state),
178            rand: SmallRng::seed_from_u64(0),
179            disconnected: self.disconnected,
180        }
181    }
182}
183
/// Error returned when an item could not be sent; it hands the unsent item back to the caller.
#[doc(hidden)]
pub struct SendError<T>(pub T);

impl<T> fmt::Debug for SendError<T>
where
    T: fmt::Debug,
{
    /// Formats the error as a tuple struct wrapping the rejected item.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(item) = self;
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(item);
        tuple.finish()
    }
}
192
/// Error returned by the receive operations when the queue is empty and no sender is left.
#[derive(Debug)]
#[doc(hidden)]
pub struct RecvError;
196
197#[allow(dead_code)]
198impl<T> PriorityQueueReceiver<T> {
199    pub fn new() -> (PriorityQueueSender<T>, Self) {
200        let state = PriorityQueueState {
201            queues: parking_lot::Mutex::new(PriorityQueues {
202                high_priority: VecDeque::new(),
203                medium_priority: VecDeque::new(),
204                low_priority: VecDeque::new(),
205            }),
206            condvar: parking_lot::Condvar::new(),
207            receiver_count: AtomicUsize::new(1),
208            sender_count: AtomicUsize::new(1),
209        };
210        let state = Arc::new(state);
211
212        let sender = PriorityQueueSender::new(Arc::clone(&state));
213
214        let receiver = PriorityQueueReceiver {
215            state,
216            rand: SmallRng::seed_from_u64(0),
217            disconnected: false,
218        };
219
220        (sender, receiver)
221    }
222
223    /// Tries to pop one element from the priority queue without blocking.
224    ///
225    /// This will early return if there are no elements in the queue.
226    ///
227    /// This method is best suited if you only intend to pop one element, for better performance
228    /// on large queues see [`Self::try_iter`]
229    ///
230    /// # Errors
231    ///
232    /// If the sender was dropped
233    pub fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
234        self.pop_inner(false)
235    }
236
237    pub fn spin_try_pop(&mut self) -> Result<Option<T>, RecvError> {
238        use Priority as P;
239
240        let Some(mut queues) = self.state.spin_try_recv()? else {
241            return Ok(None);
242        };
243
244        let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
245        let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
246        let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
247        let mut mass = high + medium + low;
248
249        if !queues.high_priority.is_empty() {
250            let flip = self.rand.random_ratio(P::High.weight(), mass);
251            if flip {
252                return Ok(queues.high_priority.pop_front());
253            }
254            mass -= P::High.weight();
255        }
256
257        if !queues.medium_priority.is_empty() {
258            let flip = self.rand.random_ratio(P::Medium.weight(), mass);
259            if flip {
260                return Ok(queues.medium_priority.pop_front());
261            }
262            mass -= P::Medium.weight();
263        }
264
265        if !queues.low_priority.is_empty() {
266            let flip = self.rand.random_ratio(P::Low.weight(), mass);
267            if flip {
268                return Ok(queues.low_priority.pop_front());
269            }
270        }
271
272        Ok(None)
273    }
274
275    /// Pops an element from the priority queue blocking if necessary.
276    ///
277    /// This method is best suited if you only intend to pop one element, for better performance
278    /// on large queues see [`Self::iter``]
279    ///
280    /// # Errors
281    ///
282    /// If the sender was dropped
283    pub fn pop(&mut self) -> Result<T, RecvError> {
284        self.pop_inner(true).map(|e| e.unwrap())
285    }
286
287    /// Returns an iterator over the elements of the queue
288    /// this iterator will end when all elements have been consumed and will not wait for new ones.
289    pub fn try_iter(self) -> TryIter<T> {
290        TryIter {
291            receiver: self,
292            ended: false,
293        }
294    }
295
296    /// Returns an iterator over the elements of the queue
297    /// this iterator will wait for new elements if the queue is empty.
298    pub fn iter(self) -> Iter<T> {
299        Iter(self)
300    }
301
302    #[inline(always)]
303    // algorithm is the loaded die from biased coin from
304    // https://www.keithschwarz.com/darts-dice-coins/
305    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
306        use Priority as P;
307
308        let mut queues = if !block {
309            let Some(queues) = self.state.try_recv()? else {
310                return Ok(None);
311            };
312            queues
313        } else {
314            self.state.recv()?
315        };
316
317        let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
318        let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
319        let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
320        let mut mass = high + medium + low; //%
321
322        if !queues.high_priority.is_empty() {
323            let flip = self.rand.random_ratio(P::High.weight(), mass);
324            if flip {
325                return Ok(queues.high_priority.pop_front());
326            }
327            mass -= P::High.weight();
328        }
329
330        if !queues.medium_priority.is_empty() {
331            let flip = self.rand.random_ratio(P::Medium.weight(), mass);
332            if flip {
333                return Ok(queues.medium_priority.pop_front());
334            }
335            mass -= P::Medium.weight();
336        }
337
338        if !queues.low_priority.is_empty() {
339            let flip = self.rand.random_ratio(P::Low.weight(), mass);
340            if flip {
341                return Ok(queues.low_priority.pop_front());
342            }
343        }
344
345        Ok(None)
346    }
347}
348
349impl<T> Drop for PriorityQueueReceiver<T> {
350    fn drop(&mut self) {
351        self.state
352            .receiver_count
353            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
354    }
355}
356
357#[doc(hidden)]
358pub struct Iter<T>(PriorityQueueReceiver<T>);
359impl<T> Iterator for Iter<T> {
360    type Item = T;
361
362    fn next(&mut self) -> Option<Self::Item> {
363        self.0.pop().ok()
364    }
365}
366impl<T> FusedIterator for Iter<T> {}
367
368#[doc(hidden)]
369pub struct TryIter<T> {
370    receiver: PriorityQueueReceiver<T>,
371    ended: bool,
372}
373impl<T> Iterator for TryIter<T> {
374    type Item = Result<T, RecvError>;
375
376    fn next(&mut self) -> Option<Self::Item> {
377        if self.ended {
378            return None;
379        }
380
381        let res = self.receiver.try_pop();
382        self.ended = res.is_err();
383
384        res.transpose()
385    }
386}
387impl<T> FusedIterator for TryIter<T> {}
388
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every queued item is eventually yielded, whatever its priority.
    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so the binding does not need to be mutable.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        // Dropping the sender lets the blocking iterator finish once the queue is drained.
        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    /// A high priority item sent behind a backlog of low priority items is popped next.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }

    /// `try_pop` distinguishes an empty queue from a disconnected one.
    #[test]
    fn try_pop_reports_empty_and_disconnected() {
        let (tx, mut rx) = PriorityQueueReceiver::<u32>::new();
        assert!(matches!(rx.try_pop(), Ok(None)));

        tx.send(Priority::Low, 7).unwrap();
        assert!(matches!(rx.try_pop(), Ok(Some(7))));

        drop(tx);
        assert!(matches!(rx.try_pop(), Err(RecvError)));
    }

    /// Sending hands the item back once every receiver is gone.
    #[test]
    fn send_fails_once_all_receivers_are_gone() {
        let (tx, rx) = PriorityQueueReceiver::new();
        drop(rx);

        assert!(matches!(
            tx.send(Priority::Medium, 5),
            Err(SendError(5))
        ));
    }
}