queue.rs

  1use std::{
  2    collections::VecDeque,
  3    fmt,
  4    iter::FusedIterator,
  5    sync::{Arc, atomic::AtomicUsize},
  6};
  7
  8use rand::{Rng, SeedableRng, rngs::SmallRng};
  9
 10use crate::Priority;
 11
 12struct PriorityQueues<T> {
 13    high_priority: VecDeque<T>,
 14    medium_priority: VecDeque<T>,
 15    low_priority: VecDeque<T>,
 16}
 17
 18impl<T> PriorityQueues<T> {
 19    fn is_empty(&self) -> bool {
 20        self.high_priority.is_empty()
 21            && self.medium_priority.is_empty()
 22            && self.low_priority.is_empty()
 23    }
 24}
 25
 26struct PriorityQueueState<T> {
 27    queues: parking_lot::Mutex<PriorityQueues<T>>,
 28    condvar: parking_lot::Condvar,
 29    receiver_count: AtomicUsize,
 30    sender_count: AtomicUsize,
 31}
 32
 33impl<T> PriorityQueueState<T> {
 34    fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 35        if self
 36            .receiver_count
 37            .load(std::sync::atomic::Ordering::Relaxed)
 38            == 0
 39        {
 40            return Err(SendError(item));
 41        }
 42
 43        let mut queues = self.queues.lock();
 44        match priority {
 45            Priority::RealtimeAudio => unreachable!(
 46                "Realtime audio priority runs on a dedicated thread and is never queued"
 47            ),
 48            Priority::High => queues.high_priority.push_back(item),
 49            Priority::Medium => queues.medium_priority.push_back(item),
 50            Priority::Low => queues.low_priority.push_back(item),
 51        };
 52        self.condvar.notify_one();
 53        Ok(())
 54    }
 55
 56    fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
 57        let mut queues = self.queues.lock();
 58
 59        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 60        if queues.is_empty() && sender_count == 0 {
 61            return Err(crate::queue::RecvError);
 62        }
 63
 64        while queues.is_empty() {
 65            self.condvar.wait(&mut queues);
 66        }
 67
 68        Ok(queues)
 69    }
 70
 71    fn try_recv<'a>(
 72        &'a self,
 73    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
 74        let mut queues = self.queues.lock();
 75
 76        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 77        if queues.is_empty() && sender_count == 0 {
 78            return Err(crate::queue::RecvError);
 79        }
 80
 81        if queues.is_empty() {
 82            Ok(None)
 83        } else {
 84            Ok(Some(queues))
 85        }
 86    }
 87}
 88
 89pub(crate) struct PriorityQueueSender<T> {
 90    state: Arc<PriorityQueueState<T>>,
 91}
 92
 93impl<T> PriorityQueueSender<T> {
 94    fn new(state: Arc<PriorityQueueState<T>>) -> Self {
 95        Self { state }
 96    }
 97
 98    pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 99        self.state.send(priority, item)?;
100        Ok(())
101    }
102}
103
104impl<T> Drop for PriorityQueueSender<T> {
105    fn drop(&mut self) {
106        self.state
107            .sender_count
108            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
109    }
110}
111
112pub(crate) struct PriorityQueueReceiver<T> {
113    state: Arc<PriorityQueueState<T>>,
114    rand: SmallRng,
115    disconnected: bool,
116}
117
118impl<T> Clone for PriorityQueueReceiver<T> {
119    fn clone(&self) -> Self {
120        self.state
121            .receiver_count
122            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
123        Self {
124            state: Arc::clone(&self.state),
125            rand: SmallRng::seed_from_u64(0),
126            disconnected: self.disconnected,
127        }
128    }
129}
130
/// Error returned when an item could not be queued because no receiver is alive.
///
/// The rejected item is carried in the tuple field so it is not lost.
pub(crate) struct SendError<T>(T);

impl<T> fmt::Debug for SendError<T>
where
    T: fmt::Debug,
{
    /// Formats as `SendError(<item>)`, like a derived tuple-struct `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(item) = self;
        f.debug_tuple("SendError").field(item).finish()
    }
}
138
/// Error returned by the receiving side when the queue is empty and no sender is left
/// to refill it.
#[derive(Debug)]
pub(crate) struct RecvError;
141
142#[allow(dead_code)]
143impl<T> PriorityQueueReceiver<T> {
144    pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
145        let state = PriorityQueueState {
146            queues: parking_lot::Mutex::new(PriorityQueues {
147                high_priority: VecDeque::new(),
148                medium_priority: VecDeque::new(),
149                low_priority: VecDeque::new(),
150            }),
151            condvar: parking_lot::Condvar::new(),
152            receiver_count: AtomicUsize::new(1),
153            sender_count: AtomicUsize::new(1),
154        };
155        let state = Arc::new(state);
156
157        let sender = PriorityQueueSender::new(Arc::clone(&state));
158
159        let receiver = PriorityQueueReceiver {
160            state,
161            rand: SmallRng::seed_from_u64(0),
162            disconnected: false,
163        };
164
165        (sender, receiver)
166    }
167
168    /// Tries to pop one element from the priority queue without blocking.
169    ///
170    /// This will early return if there are no elements in the queue.
171    ///
172    /// This method is best suited if you only intend to pop one element, for better performance
173    /// on large queues see [`Self::try_iter`]
174    ///
175    /// # Errors
176    ///
177    /// If the sender was dropped
178    pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
179        self.pop_inner(false)
180    }
181
182    /// Pops an element from the priority queue blocking if necessary.
183    ///
184    /// This method is best suited if you only intend to pop one element, for better performance
185    /// on large queues see [`Self::iter``]
186    ///
187    /// # Errors
188    ///
189    /// If the sender was dropped
190    pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
191        self.pop_inner(true).map(|e| e.unwrap())
192    }
193
194    /// Returns an iterator over the elements of the queue
195    /// this iterator will end when all elements have been consumed and will not wait for new ones.
196    pub(crate) fn try_iter(self) -> TryIter<T> {
197        TryIter {
198            receiver: self,
199            ended: false,
200        }
201    }
202
203    /// Returns an iterator over the elements of the queue
204    /// this iterator will wait for new elements if the queue is empty.
205    pub(crate) fn iter(self) -> Iter<T> {
206        Iter(self)
207    }
208
209    #[inline(always)]
210    // algorithm is the loaded die from biased coin from
211    // https://www.keithschwarz.com/darts-dice-coins/
212    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
213        use Priority as P;
214
215        let mut queues = if !block {
216            let Some(queues) = self.state.try_recv()? else {
217                return Ok(None);
218            };
219            queues
220        } else {
221            self.state.recv()?
222        };
223
224        let high = P::High.weight() * !queues.high_priority.is_empty() as u32;
225        let medium = P::Medium.weight() * !queues.medium_priority.is_empty() as u32;
226        let low = P::Low.weight() * !queues.low_priority.is_empty() as u32;
227        let mut mass = high + medium + low; //%
228
229        if !queues.high_priority.is_empty() {
230            let flip = self.rand.random_ratio(P::High.weight(), mass);
231            if flip {
232                return Ok(queues.high_priority.pop_front());
233            }
234            mass -= P::High.weight();
235        }
236
237        if !queues.medium_priority.is_empty() {
238            let flip = self.rand.random_ratio(P::Medium.weight(), mass);
239            if flip {
240                return Ok(queues.medium_priority.pop_front());
241            }
242            mass -= P::Medium.weight();
243        }
244
245        if !queues.low_priority.is_empty() {
246            let flip = self.rand.random_ratio(P::Low.weight(), mass);
247            if flip {
248                return Ok(queues.low_priority.pop_front());
249            }
250        }
251
252        Ok(None)
253    }
254}
255
256impl<T> Drop for PriorityQueueReceiver<T> {
257    fn drop(&mut self) {
258        self.state
259            .receiver_count
260            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
261    }
262}
263
264/// If None is returned the sender disconnected
265pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
266impl<T> Iterator for Iter<T> {
267    type Item = T;
268
269    fn next(&mut self) -> Option<Self::Item> {
270        self.0.pop().ok()
271    }
272}
273impl<T> FusedIterator for Iter<T> {}
274
275/// If None is returned there are no more elements in the queue
276pub(crate) struct TryIter<T> {
277    receiver: PriorityQueueReceiver<T>,
278    ended: bool,
279}
280impl<T> Iterator for TryIter<T> {
281    type Item = Result<T, RecvError>;
282
283    fn next(&mut self) -> Option<Self::Item> {
284        if self.ended {
285            return None;
286        }
287
288        let res = self.receiver.try_pop();
289        self.ended = res.is_err();
290
291        res.transpose()
292    }
293}
294impl<T> FusedIterator for TryIter<T> {}
295
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every queued item comes out again, whatever lane it went into.
    #[test]
    fn all_tasks_get_yielded() {
        let (tx, rx) = PriorityQueueReceiver::new();

        let sent = [
            (Priority::Medium, 20),
            (Priority::High, 30),
            (Priority::Low, 10),
            (Priority::Medium, 21),
            (Priority::High, 31),
        ];
        for (priority, value) in sent {
            tx.send(priority, value).unwrap();
        }

        // Dropping the sender lets the blocking iterator finish once the queue is drained.
        drop(tx);

        let received: HashSet<_> = rx.iter().collect();
        let expected: HashSet<_> = [30, 31, 20, 21, 10].into_iter().collect();
        assert_eq!(received, expected);
    }

    /// A high-priority item sent behind a backlog of low-priority ones is served next.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        (0..100).for_each(|_| tx.send(Priority::Low, 1).unwrap());

        let first = rx.pop().unwrap();
        assert_eq!(first, 1);

        tx.send(Priority::High, 3).unwrap();

        let second = rx.pop().unwrap();
        assert_eq!(second, 3);

        let third = rx.pop().unwrap();
        assert_eq!(third, 1);
    }
}