replayable.rs

  1use std::{
  2    sync::{
  3        Arc, Mutex,
  4        atomic::{AtomicBool, Ordering},
  5    },
  6    time::Duration,
  7};
  8
  9use crossbeam::queue::ArrayQueue;
 10use rodio::{ChannelCount, Sample, SampleRate, Source};
 11
/// Error returned by [`replayable`] when the requested replay window is shorter
/// than 100ms (the size of a single replay chunk).
#[derive(Debug, thiserror::Error)]
#[error("Replay duration is too short must be >= 100ms")]
pub struct ReplayDurationTooShort;
 15
 16pub fn replayable<S: Source>(
 17    source: S,
 18    duration: Duration,
 19) -> Result<(Replay, Replayable<S>), ReplayDurationTooShort> {
 20    if duration < Duration::from_millis(100) {
 21        return Err(ReplayDurationTooShort);
 22    }
 23
 24    let samples_per_second = source.sample_rate().get() as usize * source.channels().get() as usize;
 25    let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 26    let samples_to_queue =
 27        (samples_to_queue as usize).next_multiple_of(source.channels().get().into());
 28
 29    let chunk_size =
 30        (samples_per_second.div_ceil(10)).next_multiple_of(source.channels().get() as usize);
 31    let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 32
 33    let is_active = Arc::new(AtomicBool::new(true));
 34    let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 35    Ok((
 36        Replay {
 37            rx: Arc::clone(&queue),
 38            buffer: Vec::new().into_iter(),
 39            sleep_duration: duration / 2,
 40            sample_rate: source.sample_rate(),
 41            channel_count: source.channels(),
 42            source_is_active: is_active.clone(),
 43        },
 44        Replayable {
 45            tx: queue,
 46            inner: source,
 47            buffer: Vec::with_capacity(chunk_size),
 48            chunk_size,
 49            is_active,
 50        },
 51    ))
 52}
 53
 54/// constant source, only works on a single span
 55#[derive(Debug)]
 56struct ReplayQueue {
 57    inner: ArrayQueue<Vec<Sample>>,
 58    normal_chunk_len: usize,
 59    /// The last chunk in the queue may be smaller than
 60    /// the normal chunk size. This is always equal to the
 61    /// size of the last element in the queue.
 62    /// (so normally chunk_size)
 63    last_chunk: Mutex<Vec<Sample>>,
 64}
 65
 66impl ReplayQueue {
 67    fn new(queue_len: usize, chunk_size: usize) -> Self {
 68        Self {
 69            inner: ArrayQueue::new(queue_len),
 70            normal_chunk_len: chunk_size,
 71            last_chunk: Mutex::new(Vec::new()),
 72        }
 73    }
 74    /// Returns the length in samples
 75    fn len(&self) -> usize {
 76        self.inner.len().saturating_sub(1) * self.normal_chunk_len
 77            + self
 78                .last_chunk
 79                .lock()
 80                .expect("Self::push_last can not poison this lock")
 81                .len()
 82    }
 83
 84    fn pop(&self) -> Option<Vec<Sample>> {
 85        self.inner.pop() // removes element that was inserted first
 86    }
 87
 88    fn push_last(&self, mut samples: Vec<Sample>) {
 89        let mut last_chunk = self
 90            .last_chunk
 91            .lock()
 92            .expect("Self::len can not poison this lock");
 93        std::mem::swap(&mut *last_chunk, &mut samples);
 94    }
 95
 96    fn push_normal(&self, samples: Vec<Sample>) {
 97        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
 98    }
 99}
100
/// Pass-through wrapper around a source. It yields the wrapped samples unchanged
/// and forwards them to the paired [`Replay`] in chunks of `chunk_size` samples.
///
/// constant source, only works on a single span: the paired [`Replay`] captures
/// the sample rate and channel count once, when the pair is created.
#[derive(Debug)]
pub struct Replayable<S: Source> {
    /// The wrapped source.
    inner: S,
    /// Samples yielded since the last chunk was handed to the queue.
    buffer: Vec<Sample>,
    /// Number of samples per chunk handed to the replay queue.
    chunk_size: usize,
    /// Producer end of the queue shared with [`Replay`].
    tx: Arc<ReplayQueue>,
    /// Set to `false` once `inner` returns `None`.
    is_active: Arc<AtomicBool>,
}
110
111impl<S: Source> Iterator for Replayable<S> {
112    type Item = Sample;
113
114    fn next(&mut self) -> Option<Self::Item> {
115        if let Some(sample) = self.inner.next() {
116            self.buffer.push(sample);
117            // If the buffer is full send it
118            if self.buffer.len() == self.chunk_size {
119                self.tx.push_normal(std::mem::take(&mut self.buffer));
120            }
121            Some(sample)
122        } else {
123            let last_chunk = std::mem::take(&mut self.buffer);
124            self.tx.push_last(last_chunk);
125            self.is_active.store(false, Ordering::Relaxed);
126            None
127        }
128    }
129
130    fn size_hint(&self) -> (usize, Option<usize>) {
131        self.inner.size_hint()
132    }
133}
134
135impl<S: Source> Source for Replayable<S> {
136    fn current_span_len(&self) -> Option<usize> {
137        self.inner.current_span_len()
138    }
139
140    fn channels(&self) -> ChannelCount {
141        self.inner.channels()
142    }
143
144    fn sample_rate(&self) -> SampleRate {
145        self.inner.sample_rate()
146    }
147
148    fn total_duration(&self) -> Option<Duration> {
149        self.inner.total_duration()
150    }
151}
152
/// Replays the samples recorded by the paired [`Replayable`], limited to roughly
/// the `duration` passed to [`replayable`].
///
/// constant source, only works on a single span: the sample rate and channel
/// count are captured once, when the pair is created.
#[derive(Debug)]
pub struct Replay {
    /// Consumer end of the queue shared with [`Replayable`].
    rx: Arc<ReplayQueue>,
    /// Remaining samples of the chunk currently being replayed.
    buffer: std::vec::IntoIter<Sample>,
    /// How long `next` sleeps before polling the queue again.
    sleep_duration: Duration,
    sample_rate: SampleRate,
    channel_count: ChannelCount,
    /// Flag that the [`Replayable`] clears when its source returns `None`.
    source_is_active: Arc<AtomicBool>,
}
163
164impl Replay {
165    pub fn source_is_active(&self) -> bool {
166        // - source could return None and not drop
167        // - source could be dropped before returning None
168        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
169    }
170
171    /// Duration of what is in the buffer and can be returned without blocking.
172    pub fn duration_ready(&self) -> Duration {
173        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
174
175        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
176        Duration::from_secs_f64(seconds_queued)
177    }
178
179    /// Number of samples in the buffer and can be returned without blocking.
180    pub fn samples_ready(&self) -> usize {
181        self.rx.len() + self.buffer.len()
182    }
183}
184
185impl Iterator for Replay {
186    type Item = Sample;
187
188    fn next(&mut self) -> Option<Self::Item> {
189        if let Some(sample) = self.buffer.next() {
190            return Some(sample);
191        }
192
193        loop {
194            if let Some(new_buffer) = self.rx.pop() {
195                self.buffer = new_buffer.into_iter();
196                return self.buffer.next();
197            }
198
199            if !self.source_is_active() {
200                return None;
201            }
202
203            // The queue does not support blocking on a next item. We want this queue as it
204            // is quite fast and provides a fixed size. We know how many samples are in a
205            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
206            std::thread::sleep(self.sleep_duration);
207        }
208    }
209
210    fn size_hint(&self) -> (usize, Option<usize>) {
211        ((self.rx.len() + self.buffer.len()), None)
212    }
213}
214
215impl Source for Replay {
216    fn current_span_len(&self) -> Option<usize> {
217        None // source is not compatible with spans
218    }
219
220    fn channels(&self) -> ChannelCount {
221        self.channel_count
222    }
223
224    fn sample_rate(&self) -> SampleRate {
225        self.sample_rate
226    }
227
228    fn total_duration(&self) -> Option<Duration> {
229        None
230    }
231}
232
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;
    use crate::{
        RodioExt,
        rodio_ext::tests::{SAMPLES, test_source},
    };

    // NOTE(review): the `test_source` tests assume the source is slow enough
    // (presumably 1 sample per second, mono) that every replay chunk holds a
    // single sample. Confirm this in `rodio_ext::tests`.

    /// Samples recorded after the replay has caught up are still replayed, in order.
    #[test]
    fn continues_after_history() {
        let input = test_source();

        let (mut replay, mut source) = input
            .replayable(Duration::from_secs(3))
            .expect("longer than 100ms");

        // Pull three samples through the source, then replay exactly those three.
        source.by_ref().take(3).count();
        let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
        assert_eq!(&yielded, &SAMPLES[0..3],);

        // Exhaust (and drop) the source; the replay yields the remaining samples and ends.
        source.count();
        let yielded: Vec<Sample> = replay.collect();
        assert_eq!(&yielded, &SAMPLES[3..5],);
    }

    /// Only the most recent `duration` worth of samples is kept.
    #[test]
    fn keeps_only_latest() {
        let input = test_source();

        let (mut replay, mut source) = input
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");

        source.by_ref().take(5).count(); // get all items but do not end the source
        let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
        assert_eq!(&yielded, &SAMPLES[3..5]);
        source.count(); // exhaust source
        assert_eq!(replay.next(), None);
    }

    /// The replay window holds the full requested duration, within one ~100ms chunk.
    #[test]
    fn keeps_correct_amount_of_seconds() {
        // 40_000 samples at 16kHz mono is 2.5s, more than the 2s window.
        let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

        let (replay, mut source) = input
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");

        // exhaust but do not yet end source
        source.by_ref().take(40_000).count();

        // take all samples we can without blocking
        let ready = replay.samples_ready();
        let n_yielded = replay.take_samples(ready).count();

        // Samples in the 2s window, with one ~100ms chunk of slack for rounding.
        let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
        let margin = 16_000 / 10; // 100ms
        assert!(n_yielded as u32 >= max - margin);
    }

    /// `samples_ready` starts at zero and grows as the source is consumed.
    #[test]
    fn samples_ready() {
        let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
        let (mut replay, source) = input
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");
        assert_eq!(replay.by_ref().samples_ready(), 0);

        source.take(8000).count(); // half a second
        let margin = 16_000 / 10; // 100ms
        let ready = replay.samples_ready();
        assert!(ready >= 8000 - margin);
    }
}