queue.rs

  1use std::{
  2    fmt,
  3    iter::FusedIterator,
  4    sync::{Arc, atomic::AtomicUsize},
  5};
  6
  7use rand::{Rng, SeedableRng, rngs::SmallRng};
  8
  9use crate::Priority;
 10
 11struct PriorityQueues<T> {
 12    high_priority: Vec<T>,
 13    medium_priority: Vec<T>,
 14    low_priority: Vec<T>,
 15}
 16
 17impl<T> PriorityQueues<T> {
 18    fn is_empty(&self) -> bool {
 19        self.high_priority.is_empty()
 20            && self.medium_priority.is_empty()
 21            && self.low_priority.is_empty()
 22    }
 23}
 24
 25struct PriorityQueueState<T> {
 26    queues: parking_lot::Mutex<PriorityQueues<T>>,
 27    condvar: parking_lot::Condvar,
 28    receiver_count: AtomicUsize,
 29    sender_count: AtomicUsize,
 30}
 31
 32impl<T> PriorityQueueState<T> {
 33    fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 34        if self
 35            .receiver_count
 36            .load(std::sync::atomic::Ordering::Relaxed)
 37            == 0
 38        {
 39            return Err(SendError(item));
 40        }
 41
 42        let mut queues = self.queues.lock();
 43        match priority {
 44            Priority::Realtime(_) => unreachable!(),
 45            Priority::High => queues.high_priority.push(item),
 46            Priority::Medium => queues.medium_priority.push(item),
 47            Priority::Low => queues.low_priority.push(item),
 48        };
 49        self.condvar.notify_one();
 50        Ok(())
 51    }
 52
 53    fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
 54        let mut queues = self.queues.lock();
 55
 56        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 57        if queues.is_empty() && sender_count == 0 {
 58            return Err(crate::queue::RecvError);
 59        }
 60
 61        // parking_lot doesn't do spurious wakeups so an if is fine
 62        if queues.is_empty() {
 63            self.condvar.wait(&mut queues);
 64        }
 65
 66        Ok(queues)
 67    }
 68
 69    fn try_recv<'a>(
 70        &'a self,
 71    ) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
 72        let mut queues = self.queues.lock();
 73
 74        let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
 75        if queues.is_empty() && sender_count == 0 {
 76            return Err(crate::queue::RecvError);
 77        }
 78
 79        if queues.is_empty() {
 80            Ok(None)
 81        } else {
 82            Ok(Some(queues))
 83        }
 84    }
 85}
 86
 87pub(crate) struct PriorityQueueSender<T> {
 88    state: Arc<PriorityQueueState<T>>,
 89}
 90
 91impl<T> PriorityQueueSender<T> {
 92    fn new(state: Arc<PriorityQueueState<T>>) -> Self {
 93        Self { state }
 94    }
 95
 96    pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
 97        self.state.send(priority, item)?;
 98        Ok(())
 99    }
100}
101
102impl<T> Drop for PriorityQueueSender<T> {
103    fn drop(&mut self) {
104        self.state
105            .sender_count
106            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
107    }
108}
109
110pub(crate) struct PriorityQueueReceiver<T> {
111    state: Arc<PriorityQueueState<T>>,
112    rand: SmallRng,
113    disconnected: bool,
114}
115
116impl<T> Clone for PriorityQueueReceiver<T> {
117    fn clone(&self) -> Self {
118        self.state
119            .receiver_count
120            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
121        Self {
122            state: Arc::clone(&self.state),
123            rand: SmallRng::seed_from_u64(0),
124            disconnected: self.disconnected,
125        }
126    }
127}
128
/// Error returned by a failed send; carries the item that could not be queued.
pub(crate) struct SendError<T>(T);

impl<T> fmt::Debug for SendError<T>
where
    T: fmt::Debug,
{
    /// Formats as a tuple struct, e.g. `SendError(42)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(item) = self;
        let mut tuple = formatter.debug_tuple("SendError");
        tuple.field(item);
        tuple.finish()
    }
}

/// Error returned when receiving fails.
#[derive(Debug)]
pub(crate) struct RecvError;
139
140#[allow(dead_code)]
141impl<T> PriorityQueueReceiver<T> {
142    pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
143        let state = PriorityQueueState {
144            queues: parking_lot::Mutex::new(PriorityQueues {
145                high_priority: Vec::new(),
146                medium_priority: Vec::new(),
147                low_priority: Vec::new(),
148            }),
149            condvar: parking_lot::Condvar::new(),
150            receiver_count: AtomicUsize::new(1),
151            sender_count: AtomicUsize::new(1),
152        };
153        let state = Arc::new(state);
154
155        let sender = PriorityQueueSender::new(Arc::clone(&state));
156
157        let receiver = PriorityQueueReceiver {
158            state,
159            rand: SmallRng::seed_from_u64(0),
160            disconnected: false,
161        };
162
163        (sender, receiver)
164    }
165
166    /// Tries to pop one element from the priority queue without blocking.
167    ///
168    /// This will early return if there are no elements in the queue.
169    ///
170    /// This method is best suited if you only intend to pop one element, for better performance
171    /// on large queues see [`Self::try_iter`]
172    ///
173    /// # Errors
174    ///
175    /// If the sender was dropped
176    pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
177        self.pop_inner(false)
178    }
179
180    /// Pops an element from the priority queue blocking if necessary.
181    ///
182    /// This method is best suited if you only intend to pop one element, for better performance
183    /// on large queues see [`Self::iter``]
184    ///
185    /// # Errors
186    ///
187    /// If the sender was dropped
188    pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
189        self.pop_inner(true).map(|e| e.unwrap())
190    }
191
192    /// Returns an iterator over the elements of the queue
193    /// this iterator will end when all elements have been consumed and will not wait for new ones.
194    pub(crate) fn try_iter(self) -> TryIter<T> {
195        TryIter {
196            receiver: self,
197            ended: false,
198        }
199    }
200
201    /// Returns an iterator over the elements of the queue
202    /// this iterator will wait for new elements if the queue is empty.
203    pub(crate) fn iter(self) -> Iter<T> {
204        Iter(self)
205    }
206
207    #[inline(always)]
208    // algorithm is the loaded die from biased coin from
209    // https://www.keithschwarz.com/darts-dice-coins/
210    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
211        use Priority as P;
212
213        let mut queues = if !block {
214            let Some(queues) = self.state.try_recv()? else {
215                return Ok(None);
216            };
217            queues
218        } else {
219            self.state.recv()?
220        };
221
222        let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
223        let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
224        let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
225        let mut mass = high + medium + low; //%
226
227        if !queues.high_priority.is_empty() {
228            let flip = self.rand.random_ratio(P::High.probability(), mass);
229            if flip {
230                return Ok(queues.high_priority.pop());
231            }
232            mass -= P::High.probability();
233        }
234
235        if !queues.medium_priority.is_empty() {
236            let flip = self.rand.random_ratio(P::Medium.probability(), mass);
237            if flip {
238                return Ok(queues.medium_priority.pop());
239            }
240            mass -= P::Medium.probability();
241        }
242
243        if !queues.low_priority.is_empty() {
244            let flip = self.rand.random_ratio(P::Low.probability(), mass);
245            if flip {
246                return Ok(queues.low_priority.pop());
247            }
248        }
249
250        Ok(None)
251    }
252}
253
254impl<T> Drop for PriorityQueueReceiver<T> {
255    fn drop(&mut self) {
256        self.state
257            .receiver_count
258            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
259    }
260}
261
262/// If None is returned the sender disconnected
263pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
264impl<T> Iterator for Iter<T> {
265    type Item = T;
266
267    fn next(&mut self) -> Option<Self::Item> {
268        self.0.pop_inner(true).ok().flatten()
269    }
270}
271impl<T> FusedIterator for Iter<T> {}
272
273/// If None is returned there are no more elements in the queue
274pub(crate) struct TryIter<T> {
275    receiver: PriorityQueueReceiver<T>,
276    ended: bool,
277}
278impl<T> Iterator for TryIter<T> {
279    type Item = Result<T, RecvError>;
280
281    fn next(&mut self) -> Option<Self::Item> {
282        if self.ended {
283            return None;
284        }
285
286        let res = self.receiver.pop_inner(false);
287        self.ended = res.is_err();
288
289        res.transpose()
290    }
291}
292impl<T> FusedIterator for TryIter<T> {}
293
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    /// Every item sent before the sender is dropped comes out of the blocking iterator,
    /// whatever its priority. Order is deliberately not checked.
    #[test]
    fn all_tasks_get_yielded() {
        // `iter` consumes the receiver, so the binding does not need to be mutable.
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();

        drop(tx);

        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    /// A high-priority item sent behind a backlog of low-priority items is popped next.
    ///
    /// The choice between levels is random, so this relies on the receiver's fixed seed.
    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }

        assert_eq!(rx.pop().unwrap(), 1);
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}