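//! rodio_ext.rs — extension adapters for rodio sources: fixed-size buffer processing and inspection, sample-count limits, and a live replay buffer.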

  1use std::{
  2    sync::{
  3        Arc, Mutex,
  4        atomic::{AtomicBool, Ordering},
  5    },
  6    time::Duration,
  7};
  8
  9use crossbeam::queue::ArrayQueue;
 10use rodio::{ChannelCount, Sample, SampleRate, Source};
 11
 12#[derive(Debug)]
 13pub struct ReplayDurationTooShort;
 14
 15pub trait RodioExt: Source + Sized {
 16    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 17    where
 18        F: FnMut(&mut [Sample; N]);
 19    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 20    where
 21        F: FnMut(&[Sample; N]);
 22    fn replayable(
 23        self,
 24        duration: Duration,
 25    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 26    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 27}
 28
 29impl<S: Source> RodioExt for S {
 30    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 31    where
 32        F: FnMut(&mut [Sample; N]),
 33    {
 34        ProcessBuffer {
 35            inner: self,
 36            callback,
 37            buffer: [0.0; N],
 38            next: N,
 39        }
 40    }
 41    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 42    where
 43        F: FnMut(&[Sample; N]),
 44    {
 45        InspectBuffer {
 46            inner: self,
 47            callback,
 48            buffer: [0.0; N],
 49            free: 0,
 50        }
 51    }
 52    /// Maintains a live replay with a history of at least `duration` seconds.
 53    ///
 54    /// Note:
 55    /// History can be 100ms longer if the source drops before or while the
 56    /// replay is being read
 57    ///
 58    /// # Errors
 59    /// If duration is smaller then 100ms
 60    fn replayable(
 61        self,
 62        duration: Duration,
 63    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 64        if duration < Duration::from_millis(100) {
 65            return Err(ReplayDurationTooShort);
 66        }
 67
 68        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 69        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 70        let samples_to_queue =
 71            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 72
 73        let chunk_size =
 74            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 75        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 76
 77        let is_active = Arc::new(AtomicBool::new(true));
 78        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 79        Ok((
 80            Replay {
 81                rx: Arc::clone(&queue),
 82                buffer: Vec::new().into_iter(),
 83                sleep_duration: duration / 2,
 84                sample_rate: self.sample_rate(),
 85                channel_count: self.channels(),
 86                source_is_active: is_active.clone(),
 87            },
 88            Replayable {
 89                tx: queue,
 90                inner: self,
 91                buffer: Vec::with_capacity(chunk_size),
 92                chunk_size,
 93                is_active,
 94            },
 95        ))
 96    }
 97    fn take_samples(self, n: usize) -> TakeSamples<S> {
 98        TakeSamples {
 99            inner: self,
100            left_to_take: n,
101        }
102    }
103}
104
/// Source adapter returned by [`RodioExt::take_samples`]: yields at most
/// `left_to_take` more samples from `inner`.
#[derive(Debug, Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TakeSamples<S> {
    inner: S,
    /// How many more samples may still be yielded.
    left_to_take: usize,
}
109
110impl<S: Source> Iterator for TakeSamples<S> {
111    type Item = Sample;
112
113    fn next(&mut self) -> Option<Self::Item> {
114        if self.left_to_take == 0 {
115            None
116        } else {
117            self.left_to_take -= 1;
118            self.inner.next()
119        }
120    }
121
122    fn size_hint(&self) -> (usize, Option<usize>) {
123        (0, Some(self.left_to_take))
124    }
125}
126
127impl<S: Source> Source for TakeSamples<S> {
128    fn current_span_len(&self) -> Option<usize> {
129        None // does not support spans
130    }
131
132    fn channels(&self) -> ChannelCount {
133        self.inner.channels()
134    }
135
136    fn sample_rate(&self) -> SampleRate {
137        self.inner.sample_rate()
138    }
139
140    fn total_duration(&self) -> Option<Duration> {
141        Some(Duration::from_secs_f64(
142            self.left_to_take as f64
143                / self.sample_rate().get() as f64
144                / self.channels().get() as f64,
145        ))
146    }
147}
148
149#[derive(Debug)]
150struct ReplayQueue {
151    inner: ArrayQueue<Vec<Sample>>,
152    normal_chunk_len: usize,
153    /// The last chunk in the queue may be smaller then
154    /// the normal chunk size. This is always equal to the
155    /// size of the last element in the queue.
156    /// (so normally chunk_size)
157    last_chunk: Mutex<Vec<Sample>>,
158}
159
160impl ReplayQueue {
161    fn new(queue_len: usize, chunk_size: usize) -> Self {
162        Self {
163            inner: ArrayQueue::new(queue_len),
164            normal_chunk_len: chunk_size,
165            last_chunk: Mutex::new(Vec::new()),
166        }
167    }
168    /// Returns the length in samples
169    fn len(&self) -> usize {
170        self.inner.len().saturating_sub(1) * self.normal_chunk_len
171            + self
172                .last_chunk
173                .lock()
174                .expect("Self::push_last can not poison this lock")
175                .len()
176    }
177
178    fn pop(&self) -> Option<Vec<Sample>> {
179        self.inner.pop() // removes element that was inserted first
180    }
181
182    fn push_last(&self, mut samples: Vec<Sample>) {
183        let mut last_chunk = self
184            .last_chunk
185            .lock()
186            .expect("Self::len can not poison this lock");
187        std::mem::swap(&mut *last_chunk, &mut samples);
188    }
189
190    fn push_normal(&self, samples: Vec<Sample>) {
191        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
192    }
193}
194
195pub struct ProcessBuffer<const N: usize, S, F>
196where
197    S: Source + Sized,
198    F: FnMut(&mut [Sample; N]),
199{
200    inner: S,
201    callback: F,
202    /// Buffer used for both input and output.
203    buffer: [Sample; N],
204    /// Next already processed sample is at this index
205    /// in buffer.
206    ///
207    /// If this is equal to the length of the buffer we have no more samples and
208    /// we must get new ones and process them
209    next: usize,
210}
211
212impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
213where
214    S: Source + Sized,
215    F: FnMut(&mut [Sample; N]),
216{
217    type Item = Sample;
218
219    fn next(&mut self) -> Option<Self::Item> {
220        self.next += 1;
221        if self.next < self.buffer.len() {
222            let sample = self.buffer[self.next];
223            return Some(sample);
224        }
225
226        for sample in &mut self.buffer {
227            *sample = self.inner.next()?
228        }
229        (self.callback)(&mut self.buffer);
230
231        self.next = 0;
232        Some(self.buffer[0])
233    }
234
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        self.inner.size_hint()
237    }
238}
239
240impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
241where
242    S: Source + Sized,
243    F: FnMut(&mut [Sample; N]),
244{
245    fn current_span_len(&self) -> Option<usize> {
246        None
247    }
248
249    fn channels(&self) -> rodio::ChannelCount {
250        self.inner.channels()
251    }
252
253    fn sample_rate(&self) -> rodio::SampleRate {
254        self.inner.sample_rate()
255    }
256
257    fn total_duration(&self) -> Option<std::time::Duration> {
258        self.inner.total_duration()
259    }
260}
261
262pub struct InspectBuffer<const N: usize, S, F>
263where
264    S: Source + Sized,
265    F: FnMut(&[Sample; N]),
266{
267    inner: S,
268    callback: F,
269    /// Stores already emitted samples, once its full we call the callback.
270    buffer: [Sample; N],
271    /// Next free element in buffer. If this is equal to the buffer length
272    /// we have no more free lements.
273    free: usize,
274}
275
276impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
277where
278    S: Source + Sized,
279    F: FnMut(&[Sample; N]),
280{
281    type Item = Sample;
282
283    fn next(&mut self) -> Option<Self::Item> {
284        let Some(sample) = self.inner.next() else {
285            return None;
286        };
287
288        self.buffer[self.free] = sample;
289        self.free += 1;
290
291        if self.free == self.buffer.len() {
292            (self.callback)(&self.buffer);
293            self.free = 0
294        }
295
296        Some(sample)
297    }
298
299    fn size_hint(&self) -> (usize, Option<usize>) {
300        self.inner.size_hint()
301    }
302}
303
304impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
305where
306    S: Source + Sized,
307    F: FnMut(&[Sample; N]),
308{
309    fn current_span_len(&self) -> Option<usize> {
310        None
311    }
312
313    fn channels(&self) -> rodio::ChannelCount {
314        self.inner.channels()
315    }
316
317    fn sample_rate(&self) -> rodio::SampleRate {
318        self.inner.sample_rate()
319    }
320
321    fn total_duration(&self) -> Option<std::time::Duration> {
322        self.inner.total_duration()
323    }
324}
325
326#[derive(Debug)]
327pub struct Replayable<S: Source> {
328    inner: S,
329    buffer: Vec<Sample>,
330    chunk_size: usize,
331    tx: Arc<ReplayQueue>,
332    is_active: Arc<AtomicBool>,
333}
334
335impl<S: Source> Iterator for Replayable<S> {
336    type Item = Sample;
337
338    fn next(&mut self) -> Option<Self::Item> {
339        if let Some(sample) = self.inner.next() {
340            self.buffer.push(sample);
341            if self.buffer.len() == self.chunk_size {
342                self.tx.push_normal(std::mem::take(&mut self.buffer));
343            }
344            Some(sample)
345        } else {
346            let last_chunk = std::mem::take(&mut self.buffer);
347            self.tx.push_last(last_chunk);
348            self.is_active.store(false, Ordering::Relaxed);
349            None
350        }
351    }
352
353    fn size_hint(&self) -> (usize, Option<usize>) {
354        self.inner.size_hint()
355    }
356}
357
358impl<S: Source> Source for Replayable<S> {
359    fn current_span_len(&self) -> Option<usize> {
360        self.inner.current_span_len()
361    }
362
363    fn channels(&self) -> ChannelCount {
364        self.inner.channels()
365    }
366
367    fn sample_rate(&self) -> SampleRate {
368        self.inner.sample_rate()
369    }
370
371    fn total_duration(&self) -> Option<Duration> {
372        self.inner.total_duration()
373    }
374}
375
376#[derive(Debug)]
377pub struct Replay {
378    rx: Arc<ReplayQueue>,
379    buffer: std::vec::IntoIter<Sample>,
380    sleep_duration: Duration,
381    sample_rate: SampleRate,
382    channel_count: ChannelCount,
383    source_is_active: Arc<AtomicBool>,
384}
385
386impl Replay {
387    pub fn source_is_active(&self) -> bool {
388        // - source could return None and not drop
389        // - source could be dropped before returning None
390        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
391    }
392
393    /// Duration of what is in the buffer and can be returned without blocking.
394    pub fn duration_ready(&self) -> Duration {
395        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
396
397        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
398        Duration::from_secs_f64(seconds_queued)
399    }
400
401    /// Number of samples in the buffer and can be returned without blocking.
402    pub fn samples_ready(&self) -> usize {
403        self.rx.len() + self.buffer.len()
404    }
405}
406
407impl Iterator for Replay {
408    type Item = Sample;
409
410    fn next(&mut self) -> Option<Self::Item> {
411        if let Some(sample) = self.buffer.next() {
412            return Some(sample);
413        }
414
415        loop {
416            if let Some(new_buffer) = self.rx.pop() {
417                self.buffer = new_buffer.into_iter();
418                return self.buffer.next();
419            }
420
421            if !self.source_is_active() {
422                return None;
423            }
424
425            std::thread::sleep(self.sleep_duration);
426        }
427    }
428
429    fn size_hint(&self) -> (usize, Option<usize>) {
430        ((self.rx.len() + self.buffer.len()), None)
431    }
432}
433
434impl Source for Replay {
435    fn current_span_len(&self) -> Option<usize> {
436        None // source is not compatible with spans
437    }
438
439    fn channels(&self) -> ChannelCount {
440        self.channel_count
441    }
442
443    fn sample_rate(&self) -> SampleRate {
444        self.sample_rate
445    }
446
447    fn total_duration(&self) -> Option<Duration> {
448        None
449    }
450}
451
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Five mono samples at 1 Hz with the values 0.0 through 4.0.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let processed = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|chunk| assert_eq!(*chunk, SAMPLES));
            let _ = processed.count();
        }

        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.iter().map(|s| s + 1.0).collect();

            let yielded: Vec<_> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|chunk| {
                    chunk.iter_mut().for_each(|s| *s += 1.0);
                })
                .collect();

            assert_eq!(yielded, expected)
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            let n = test_source()
                .process_buffer::<3, _>(|chunk| assert_eq!(chunk, &SAMPLES[..3]))
                .count();
            assert_eq!(n, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let inspected = test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|chunk| assert_eq!(*chunk, SAMPLES));
            let _ = inspected.count();
        }

        #[test]
        fn source_does_not_truncate() {
            let seen = test_source()
                .inspect_buffer::<3, _>(|chunk| assert_eq!(chunk, &SAMPLES[..3]))
                .count();
            assert_eq!(seen, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer then 100ms");

            // Play the first three samples, then read them back.
            source.by_ref().take(3).count();
            let first: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(first.as_slice(), &SAMPLES[..3]);

            // Finish the source; the replay follows it to the end.
            source.count();
            let rest: Vec<Sample> = replay.collect();
            assert_eq!(rest.as_slice(), &SAMPLES[3..]);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // Pull every sample without letting the source report its end.
            source.by_ref().take(5).count();
            let latest: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(latest.as_slice(), &SAMPLES[3..]);

            // Now end the source: nothing is left to replay.
            source.count();
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // Pull every sample without letting the source report its end.
            source.by_ref().take(40_000).count();

            // Read back exactly what is available without blocking.
            let available = replay.samples_ready();
            let read_back = replay.take_samples(available).count();

            let two_seconds = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(read_back as u32 >= two_seconds - margin);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let available = replay.samples_ready();
            assert!(available >= 8000 - margin);
        }
    }
}