queue.rs

  1use std::{
  2    fmt,
  3    iter::FusedIterator,
  4    sync::{Arc, atomic::AtomicUsize},
  5};
  6
  7use rand::{Rng, SeedableRng, rngs::SmallRng};
  8
  9use crate::Priority;
 10
 11struct PriorityQueues<T> {
 12    high_priority: Vec<T>,
 13    medium_priority: Vec<T>,
 14    low_priority: Vec<T>,
 15}
 16
 17impl<T> PriorityQueues<T> {
 18    fn is_empty(&self) -> bool {
 19        self.high_priority.is_empty()
 20            && self.medium_priority.is_empty()
 21            && self.low_priority.is_empty()
 22    }
 23}
 24
 25struct PriorityQueueState<T> {
 26    queues: parking_lot::Mutex<PriorityQueues<T>>,
 27    condvar: parking_lot::Condvar,
 28    receiver_count: AtomicUsize,
 29    sender_count: AtomicUsize,
 30}
 31
 32impl<T> PriorityQueueState<T> {
 33    fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 34        if self
 35            .receiver_count
 36            .load(std::sync::atomic::Ordering::Relaxed)
 37            == 0
 38        {
 39            return Err(SendError(item));
 40        }
 41
 42        let mut queues = self.queues.lock();
 43        match priority {
 44            Priority::Realtime(_) => unreachable!(),
 45            Priority::High => queues.high_priority.push(item),
 46            Priority::Medium => queues.medium_priority.push(item),
 47            Priority::Low => queues.low_priority.push(item),
 48        };
 49        self.condvar.notify_one();
 50        Ok(())
 51    }
 52
 53    fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
 54        let mut queues = self.queues.lock();
 55
 56        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 57        if queues.is_empty() && sender_count == 0 {
 58            return Err(crate::queue::RecvError);
 59        }
 60
 61        while queues.is_empty() {
 62            self.condvar.wait(&mut queues);
 63        }
 64
 65        Ok(queues)
 66    }
 67
 68    fn try_recv<'a>(
 69        &'a self,
 70    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
 71        let mut queues = self.queues.lock();
 72
 73        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 74        if queues.is_empty() && sender_count == 0 {
 75            return Err(crate::queue::RecvError);
 76        }
 77
 78        if queues.is_empty() {
 79            Ok(None)
 80        } else {
 81            Ok(Some(queues))
 82        }
 83    }
 84}
 85
 86pub(crate) struct PriorityQueueSender<T> {
 87    state: Arc<PriorityQueueState<T>>,
 88}
 89
 90impl<T> PriorityQueueSender<T> {
 91    fn new(state: Arc<PriorityQueueState<T>>) -> Self {
 92        Self { state }
 93    }
 94
 95    pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 96        self.state.send(priority, item)?;
 97        Ok(())
 98    }
 99}
100
101impl<T> Drop for PriorityQueueSender<T> {
102    fn drop(&mut self) {
103        self.state
104            .sender_count
105            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
106    }
107}
108
109pub(crate) struct PriorityQueueReceiver<T> {
110    state: Arc<PriorityQueueState<T>>,
111    rand: SmallRng,
112    disconnected: bool,
113}
114
115impl<T> Clone for PriorityQueueReceiver<T> {
116    fn clone(&self) -> Self {
117        self.state
118            .receiver_count
119            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
120        Self {
121            state: Arc::clone(&self.state),
122            rand: SmallRng::seed_from_u64(0),
123            disconnected: self.disconnected,
124        }
125    }
126}
127
/// Error returned by `send` when no receiver is left; holds the item that could not be
/// delivered.
pub(crate) struct SendError<T>(T);

impl<T: fmt::Debug> fmt::Debug for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(item) = self;
        let mut tuple = f.debug_tuple("SendError");
        tuple.field(item);
        tuple.finish()
    }
}

/// Error returned by the receiving side when the queue is empty and every sender has been
/// dropped.
#[derive(Debug)]
pub(crate) struct RecvError;
138
139#[allow(dead_code)]
140impl<T> PriorityQueueReceiver<T> {
141    pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
142        let state = PriorityQueueState {
143            queues: parking_lot::Mutex::new(PriorityQueues {
144                high_priority: Vec::new(),
145                medium_priority: Vec::new(),
146                low_priority: Vec::new(),
147            }),
148            condvar: parking_lot::Condvar::new(),
149            receiver_count: AtomicUsize::new(1),
150            sender_count: AtomicUsize::new(1),
151        };
152        let state = Arc::new(state);
153
154        let sender = PriorityQueueSender::new(Arc::clone(&state));
155
156        let receiver = PriorityQueueReceiver {
157            state,
158            rand: SmallRng::seed_from_u64(0),
159            disconnected: false,
160        };
161
162        (sender, receiver)
163    }
164
165    /// Tries to pop one element from the priority queue without blocking.
166    ///
167    /// This will early return if there are no elements in the queue.
168    ///
169    /// This method is best suited if you only intend to pop one element, for better performance
170    /// on large queues see [`Self::try_iter`]
171    ///
172    /// # Errors
173    ///
174    /// If the sender was dropped
175    pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
176        self.pop_inner(false)
177    }
178
179    /// Pops an element from the priority queue blocking if necessary.
180    ///
181    /// This method is best suited if you only intend to pop one element, for better performance
182    /// on large queues see [`Self::iter``]
183    ///
184    /// # Errors
185    ///
186    /// If the sender was dropped
187    pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
188        self.pop_inner(true).map(|e| e.unwrap())
189    }
190
191    /// Returns an iterator over the elements of the queue
192    /// this iterator will end when all elements have been consumed and will not wait for new ones.
193    pub(crate) fn try_iter(self) -> TryIter<T> {
194        TryIter {
195            receiver: self,
196            ended: false,
197        }
198    }
199
200    /// Returns an iterator over the elements of the queue
201    /// this iterator will wait for new elements if the queue is empty.
202    pub(crate) fn iter(self) -> Iter<T> {
203        Iter(self)
204    }
205
206    #[inline(always)]
207    // algorithm is the loaded die from biased coin from
208    // https://www.keithschwarz.com/darts-dice-coins/
209    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
210        use Priority as P;
211
212        let mut queues = if !block {
213            let Some(queues) = self.state.try_recv()? else {
214                return Ok(None);
215            };
216            queues
217        } else {
218            self.state.recv()?
219        };
220
221        let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
222        let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
223        let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
224        let mut mass = high + medium + low; //%
225
226        if !queues.high_priority.is_empty() {
227            let flip = self.rand.random_ratio(P::High.probability(), mass);
228            if flip {
229                return Ok(queues.high_priority.pop());
230            }
231            mass -= P::High.probability();
232        }
233
234        if !queues.medium_priority.is_empty() {
235            let flip = self.rand.random_ratio(P::Medium.probability(), mass);
236            if flip {
237                return Ok(queues.medium_priority.pop());
238            }
239            mass -= P::Medium.probability();
240        }
241
242        if !queues.low_priority.is_empty() {
243            let flip = self.rand.random_ratio(P::Low.probability(), mass);
244            if flip {
245                return Ok(queues.low_priority.pop());
246            }
247        }
248
249        Ok(None)
250    }
251}
252
253impl<T> Drop for PriorityQueueReceiver<T> {
254    fn drop(&mut self) {
255        self.state
256            .receiver_count
257            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
258    }
259}
260
261/// If None is returned the sender disconnected
262pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
263impl<T> Iterator for Iter<T> {
264    type Item = T;
265
266    fn next(&mut self) -> Option<Self::Item> {
267        self.0.pop().ok()
268    }
269}
270impl<T> FusedIterator for Iter<T> {}
271
272/// If None is returned there are no more elements in the queue
273pub(crate) struct TryIter<T> {
274    receiver: PriorityQueueReceiver<T>,
275    ended: bool,
276}
277impl<T> Iterator for TryIter<T> {
278    type Item = Result<T, RecvError>;
279
280    fn next(&mut self) -> Option<Self::Item> {
281        if self.ended {
282            return None;
283        }
284
285        let res = self.receiver.try_pop();
286        self.ended = res.is_err();
287
288        res.transpose()
289    }
290}
291impl<T> FusedIterator for TryIter<T> {}
292
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so the binding does not need to be `mut`.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }

    #[test]
    fn try_pop_distinguishes_empty_from_disconnected() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        assert!(matches!(rx.try_pop(), Ok(None)));

        tx.send(Priority::Medium, 7).unwrap();
        drop(tx);

        // Items sent before the sender was dropped are still delivered...
        assert!(matches!(rx.try_pop(), Ok(Some(7))));
        // ...and only then is the disconnection reported.
        assert!(rx.try_pop().is_err());
    }

    #[test]
    fn send_fails_once_all_receivers_are_dropped() {
        let (tx, rx) = PriorityQueueReceiver::new();
        let rx2 = rx.clone();

        drop(rx);
        // One receiver is still alive.
        tx.send(Priority::Low, 1).unwrap();

        drop(rx2);
        assert!(tx.send(Priority::Low, 2).is_err());
    }
}