// rodio_ext.rs

use std::{
    sync::{
        Arc, Mutex, MutexGuard, PoisonError,
        atomic::{AtomicBool, Ordering},
    },
    time::Duration,
};

use crossbeam::queue::ArrayQueue;
use rodio::{ChannelCount, Sample, SampleRate, Source};

 12#[derive(Debug, thiserror::Error)]
 13#[error("Replay duration is too short must be >= 100ms")]
 14pub struct ReplayDurationTooShort;
 15
 16pub trait RodioExt: Source + Sized {
 17    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 18    where
 19        F: FnMut(&mut [Sample; N]);
 20    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 21    where
 22        F: FnMut(&[Sample; N]);
 23    fn replayable(
 24        self,
 25        duration: Duration,
 26    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 27    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 28}
 29
 30impl<S: Source> RodioExt for S {
 31    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 32    where
 33        F: FnMut(&mut [Sample; N]),
 34    {
 35        ProcessBuffer {
 36            inner: self,
 37            callback,
 38            buffer: [0.0; N],
 39            next: N,
 40        }
 41    }
 42    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 43    where
 44        F: FnMut(&[Sample; N]),
 45    {
 46        InspectBuffer {
 47            inner: self,
 48            callback,
 49            buffer: [0.0; N],
 50            free: 0,
 51        }
 52    }
 53    /// Maintains a live replay with a history of at least `duration` seconds.
 54    ///
 55    /// Note:
 56    /// History can be 100ms longer if the source drops before or while the
 57    /// replay is being read
 58    ///
 59    /// # Errors
 60    /// If duration is smaller than 100ms
 61    fn replayable(
 62        self,
 63        duration: Duration,
 64    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 65        if duration < Duration::from_millis(100) {
 66            return Err(ReplayDurationTooShort);
 67        }
 68
 69        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 70        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 71        let samples_to_queue =
 72            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 73
 74        let chunk_size =
 75            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 76        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 77
 78        let is_active = Arc::new(AtomicBool::new(true));
 79        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 80        Ok((
 81            Replay {
 82                rx: Arc::clone(&queue),
 83                buffer: Vec::new().into_iter(),
 84                sleep_duration: duration / 2,
 85                sample_rate: self.sample_rate(),
 86                channel_count: self.channels(),
 87                source_is_active: is_active.clone(),
 88            },
 89            Replayable {
 90                tx: queue,
 91                inner: self,
 92                buffer: Vec::with_capacity(chunk_size),
 93                chunk_size,
 94                is_active,
 95            },
 96        ))
 97    }
 98    fn take_samples(self, n: usize) -> TakeSamples<S> {
 99        TakeSamples {
100            inner: self,
101            left_to_take: n,
102        }
103    }
104}
105
/// Adapter that yields at most a fixed number of samples from `inner`.
///
/// Created by [`RodioExt::take_samples`].
pub struct TakeSamples<S> {
    /// Wrapped source.
    inner: S,
    /// How many more samples may still be yielded.
    left_to_take: usize,
}

111impl<S: Source> Iterator for TakeSamples<S> {
112    type Item = Sample;
113
114    fn next(&mut self) -> Option<Self::Item> {
115        if self.left_to_take == 0 {
116            None
117        } else {
118            self.left_to_take -= 1;
119            self.inner.next()
120        }
121    }
122
123    fn size_hint(&self) -> (usize, Option<usize>) {
124        (0, Some(self.left_to_take))
125    }
126}
127
128impl<S: Source> Source for TakeSamples<S> {
129    fn current_span_len(&self) -> Option<usize> {
130        None // does not support spans
131    }
132
133    fn channels(&self) -> ChannelCount {
134        self.inner.channels()
135    }
136
137    fn sample_rate(&self) -> SampleRate {
138        self.inner.sample_rate()
139    }
140
141    fn total_duration(&self) -> Option<Duration> {
142        Some(Duration::from_secs_f64(
143            self.left_to_take as f64
144                / self.sample_rate().get() as f64
145                / self.channels().get() as f64,
146        ))
147    }
148}
149
150#[derive(Debug)]
151struct ReplayQueue {
152    inner: ArrayQueue<Vec<Sample>>,
153    normal_chunk_len: usize,
154    /// The last chunk in the queue may be smaller than
155    /// the normal chunk size. This is always equal to the
156    /// size of the last element in the queue.
157    /// (so normally chunk_size)
158    last_chunk: Mutex<Vec<Sample>>,
159}
160
161impl ReplayQueue {
162    fn new(queue_len: usize, chunk_size: usize) -> Self {
163        Self {
164            inner: ArrayQueue::new(queue_len),
165            normal_chunk_len: chunk_size,
166            last_chunk: Mutex::new(Vec::new()),
167        }
168    }
169    /// Returns the length in samples
170    fn len(&self) -> usize {
171        self.inner.len().saturating_sub(1) * self.normal_chunk_len
172            + self
173                .last_chunk
174                .lock()
175                .expect("Self::push_last can not poison this lock")
176                .len()
177    }
178
179    fn pop(&self) -> Option<Vec<Sample>> {
180        self.inner.pop() // removes element that was inserted first
181    }
182
183    fn push_last(&self, mut samples: Vec<Sample>) {
184        let mut last_chunk = self
185            .last_chunk
186            .lock()
187            .expect("Self::len can not poison this lock");
188        std::mem::swap(&mut *last_chunk, &mut samples);
189    }
190
191    fn push_normal(&self, samples: Vec<Sample>) {
192        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
193    }
194}
195
196pub struct ProcessBuffer<const N: usize, S, F>
197where
198    S: Source + Sized,
199    F: FnMut(&mut [Sample; N]),
200{
201    inner: S,
202    callback: F,
203    /// Buffer used for both input and output.
204    buffer: [Sample; N],
205    /// Next already processed sample is at this index
206    /// in buffer.
207    ///
208    /// If this is equal to the length of the buffer we have no more samples and
209    /// we must get new ones and process them
210    next: usize,
211}
212
213impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
214where
215    S: Source + Sized,
216    F: FnMut(&mut [Sample; N]),
217{
218    type Item = Sample;
219
220    fn next(&mut self) -> Option<Self::Item> {
221        self.next += 1;
222        if self.next < self.buffer.len() {
223            let sample = self.buffer[self.next];
224            return Some(sample);
225        }
226
227        for sample in &mut self.buffer {
228            *sample = self.inner.next()?
229        }
230        (self.callback)(&mut self.buffer);
231
232        self.next = 0;
233        Some(self.buffer[0])
234    }
235
236    fn size_hint(&self) -> (usize, Option<usize>) {
237        self.inner.size_hint()
238    }
239}
240
241impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
242where
243    S: Source + Sized,
244    F: FnMut(&mut [Sample; N]),
245{
246    fn current_span_len(&self) -> Option<usize> {
247        None
248    }
249
250    fn channels(&self) -> rodio::ChannelCount {
251        self.inner.channels()
252    }
253
254    fn sample_rate(&self) -> rodio::SampleRate {
255        self.inner.sample_rate()
256    }
257
258    fn total_duration(&self) -> Option<std::time::Duration> {
259        self.inner.total_duration()
260    }
261}
262
263pub struct InspectBuffer<const N: usize, S, F>
264where
265    S: Source + Sized,
266    F: FnMut(&[Sample; N]),
267{
268    inner: S,
269    callback: F,
270    /// Stores already emitted samples, once its full we call the callback.
271    buffer: [Sample; N],
272    /// Next free element in buffer. If this is equal to the buffer length
273    /// we have no more free lements.
274    free: usize,
275}
276
277impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
278where
279    S: Source + Sized,
280    F: FnMut(&[Sample; N]),
281{
282    type Item = Sample;
283
284    fn next(&mut self) -> Option<Self::Item> {
285        let Some(sample) = self.inner.next() else {
286            return None;
287        };
288
289        self.buffer[self.free] = sample;
290        self.free += 1;
291
292        if self.free == self.buffer.len() {
293            (self.callback)(&self.buffer);
294            self.free = 0
295        }
296
297        Some(sample)
298    }
299
300    fn size_hint(&self) -> (usize, Option<usize>) {
301        self.inner.size_hint()
302    }
303}
304
305impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
306where
307    S: Source + Sized,
308    F: FnMut(&[Sample; N]),
309{
310    fn current_span_len(&self) -> Option<usize> {
311        None
312    }
313
314    fn channels(&self) -> rodio::ChannelCount {
315        self.inner.channels()
316    }
317
318    fn sample_rate(&self) -> rodio::SampleRate {
319        self.inner.sample_rate()
320    }
321
322    fn total_duration(&self) -> Option<std::time::Duration> {
323        self.inner.total_duration()
324    }
325}
326
327#[derive(Debug)]
328pub struct Replayable<S: Source> {
329    inner: S,
330    buffer: Vec<Sample>,
331    chunk_size: usize,
332    tx: Arc<ReplayQueue>,
333    is_active: Arc<AtomicBool>,
334}
335
336impl<S: Source> Iterator for Replayable<S> {
337    type Item = Sample;
338
339    fn next(&mut self) -> Option<Self::Item> {
340        if let Some(sample) = self.inner.next() {
341            self.buffer.push(sample);
342            // If the buffer is full send it
343            if self.buffer.len() == self.chunk_size {
344                self.tx.push_normal(std::mem::take(&mut self.buffer));
345            }
346            Some(sample)
347        } else {
348            let last_chunk = std::mem::take(&mut self.buffer);
349            self.tx.push_last(last_chunk);
350            self.is_active.store(false, Ordering::Relaxed);
351            None
352        }
353    }
354
355    fn size_hint(&self) -> (usize, Option<usize>) {
356        self.inner.size_hint()
357    }
358}
359
360impl<S: Source> Source for Replayable<S> {
361    fn current_span_len(&self) -> Option<usize> {
362        self.inner.current_span_len()
363    }
364
365    fn channels(&self) -> ChannelCount {
366        self.inner.channels()
367    }
368
369    fn sample_rate(&self) -> SampleRate {
370        self.inner.sample_rate()
371    }
372
373    fn total_duration(&self) -> Option<Duration> {
374        self.inner.total_duration()
375    }
376}
377
378#[derive(Debug)]
379pub struct Replay {
380    rx: Arc<ReplayQueue>,
381    buffer: std::vec::IntoIter<Sample>,
382    sleep_duration: Duration,
383    sample_rate: SampleRate,
384    channel_count: ChannelCount,
385    source_is_active: Arc<AtomicBool>,
386}
387
388impl Replay {
389    pub fn source_is_active(&self) -> bool {
390        // - source could return None and not drop
391        // - source could be dropped before returning None
392        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
393    }
394
395    /// Duration of what is in the buffer and can be returned without blocking.
396    pub fn duration_ready(&self) -> Duration {
397        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
398
399        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
400        Duration::from_secs_f64(seconds_queued)
401    }
402
403    /// Number of samples in the buffer and can be returned without blocking.
404    pub fn samples_ready(&self) -> usize {
405        self.rx.len() + self.buffer.len()
406    }
407}
408
409impl Iterator for Replay {
410    type Item = Sample;
411
412    fn next(&mut self) -> Option<Self::Item> {
413        if let Some(sample) = self.buffer.next() {
414            return Some(sample);
415        }
416
417        loop {
418            if let Some(new_buffer) = self.rx.pop() {
419                self.buffer = new_buffer.into_iter();
420                return self.buffer.next();
421            }
422
423            if !self.source_is_active() {
424                return None;
425            }
426
427            // The queue does not support blocking on a next item. We want this queue as it
428            // is quite fast and provides a fixed size. We know how many samples are in a
429            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
430            std::thread::sleep(self.sleep_duration);
431        }
432    }
433
434    fn size_hint(&self) -> (usize, Option<usize>) {
435        ((self.rx.len() + self.buffer.len()), None)
436    }
437}
438
439impl Source for Replay {
440    fn current_span_len(&self) -> Option<usize> {
441        None // source is not compatible with spans
442    }
443
444    fn channels(&self) -> ChannelCount {
445        self.channel_count
446    }
447
448    fn sample_rate(&self) -> SampleRate {
449        self.sample_rate
450    }
451
452    fn total_duration(&self) -> Option<Duration> {
453        None
454    }
455}
456
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Five samples; with the 1-channel, 1 Hz test source each one is a second.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let processed = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES));
            let _ = processed.count();
        }

        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.into_iter().map(|sample| sample + 1.0).collect();
            let yielded: Vec<Sample> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    buffer.iter_mut().for_each(|sample| *sample += 1.0);
                })
                .collect();
            assert_eq!(yielded, expected);
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            // Five samples fill only one buffer of three; the last two are dropped.
            let yielded = test_source()
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let inspected = test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES));
            let _ = inspected.count();
        }

        #[test]
        fn source_does_not_truncate() {
            let yielded = test_source()
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len());
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let first: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&first, &SAMPLES[0..3]);

            source.count();
            let rest: Vec<Sample> = replay.collect();
            assert_eq!(&rest, &SAMPLES[3..5]);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Read every sample without letting the source report its end.
            source.by_ref().take(5).count();
            let latest: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&latest, &SAMPLES[3..5]);

            // Now let the source end.
            source.count();
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Pull every sample, but stop before the source reports its end.
            source.by_ref().take(40_000).count();

            // Read only what is available without blocking.
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let two_seconds = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // one 100ms chunk
            assert!(n_yielded as u32 >= two_seconds - margin);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).count(); // half a second at 16 kHz mono
            let margin = 16_000 / 10; // one 100ms chunk
            assert!(replay.samples_ready() >= 8000 - margin);
        }
    }
}