rodio_ext.rs

  1use std::{
  2    f32,
  3    sync::{
  4        Arc, Mutex,
  5        atomic::{AtomicBool, Ordering},
  6    },
  7    time::Duration,
  8};
  9
 10use crossbeam::queue::ArrayQueue;
 11use denoise::{Denoiser, DenoiserError};
 12use rodio::{ChannelCount, Sample, SampleRate, Source, source::UniformSourceIterator};
 13
 14#[derive(Debug, thiserror::Error)]
 15#[error("Replay duration is too short must be >= 100ms")]
 16pub struct ReplayDurationTooShort;
 17
 18pub trait RodioExt: Source + Sized {
 19    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 20    where
 21        F: FnMut(&mut [Sample; N]);
 22    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 23    where
 24        F: FnMut(&[Sample; N]);
 25    fn replayable(
 26        self,
 27        duration: Duration,
 28    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 29    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 30    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
 31    fn constant_params(
 32        self,
 33        channel_count: ChannelCount,
 34        sample_rate: SampleRate,
 35    ) -> UniformSourceIterator<Self>;
 36    fn suspicious_stereo_to_mono(self) -> ToMono<Self>;
 37}
 38
 39impl<S: Source> RodioExt for S {
 40    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 41    where
 42        F: FnMut(&mut [Sample; N]),
 43    {
 44        ProcessBuffer {
 45            inner: self,
 46            callback,
 47            buffer: [0.0; N],
 48            next: N,
 49        }
 50    }
 51    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 52    where
 53        F: FnMut(&[Sample; N]),
 54    {
 55        InspectBuffer {
 56            inner: self,
 57            callback,
 58            buffer: [0.0; N],
 59            free: 0,
 60        }
 61    }
 62    /// Maintains a live replay with a history of at least `duration` seconds.
 63    ///
 64    /// Note:
 65    /// History can be 100ms longer if the source drops before or while the
 66    /// replay is being read
 67    ///
 68    /// # Errors
 69    /// If duration is smaller than 100ms
 70    fn replayable(
 71        self,
 72        duration: Duration,
 73    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 74        if duration < Duration::from_millis(100) {
 75            return Err(ReplayDurationTooShort);
 76        }
 77
 78        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 79        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 80        let samples_to_queue =
 81            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 82
 83        let chunk_size =
 84            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 85        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 86
 87        let is_active = Arc::new(AtomicBool::new(true));
 88        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 89        Ok((
 90            Replay {
 91                rx: Arc::clone(&queue),
 92                buffer: Vec::new().into_iter(),
 93                sleep_duration: duration / 2,
 94                sample_rate: self.sample_rate(),
 95                channel_count: self.channels(),
 96                source_is_active: is_active.clone(),
 97            },
 98            Replayable {
 99                tx: queue,
100                inner: self,
101                buffer: Vec::with_capacity(chunk_size),
102                chunk_size,
103                is_active,
104            },
105        ))
106    }
107    fn take_samples(self, n: usize) -> TakeSamples<S> {
108        TakeSamples {
109            inner: self,
110            left_to_take: n,
111        }
112    }
113    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
114        let res = Denoiser::try_new(self);
115        res
116    }
117    fn constant_params(
118        self,
119        channel_count: ChannelCount,
120        sample_rate: SampleRate,
121    ) -> UniformSourceIterator<Self> {
122        UniformSourceIterator::new(self, channel_count, sample_rate)
123    }
124    fn suspicious_stereo_to_mono(self) -> ToMono<Self> {
125        ToMono {
126            input_channel_count: self.channels(),
127            inner: self,
128            mean: f32::EPSILON * 3.0,
129        }
130    }
131}
132
133/// constant source, only works on a single span
134pub struct ToMono<S> {
135    inner: S,
136    input_channel_count: ChannelCount,
137    /// running mean of second channel 'volume'
138    mean: f32,
139}
140impl<S> ToMono<S> {
141    fn real_stereo(&self) -> bool {
142        dbg!(self.mean);
143        dbg!(self.mean >= f32::EPSILON * 3.0)
144    }
145
146    fn update_mean(&mut self, second_channel: Sample) {
147        const HISTORY: f32 = 500.0;
148        self.mean *= (HISTORY - 1.0) / HISTORY;
149        self.mean += second_channel.abs() / HISTORY;
150    }
151}
152
153impl<S: Source> Source for ToMono<S> {
154    fn current_span_len(&self) -> Option<usize> {
155        None
156    }
157
158    fn channels(&self) -> ChannelCount {
159        rodio::nz!(1)
160    }
161
162    fn sample_rate(&self) -> SampleRate {
163        self.inner.sample_rate()
164    }
165
166    fn total_duration(&self) -> Option<Duration> {
167        self.inner.total_duration()
168    }
169}
170
171impl<S: Source> Iterator for ToMono<S> {
172    type Item = Sample;
173
174    fn next(&mut self) -> Option<Self::Item> {
175        match self.input_channel_count.get() {
176            1 => self.next(),
177            2 => {
178                let first_channel = self.inner.next().unwrap();
179                let second_channel = self.inner.next()?;
180                self.update_mean(second_channel);
181
182                if self.real_stereo() {
183                    Some((first_channel + second_channel) / 2.0)
184                } else {
185                    Some(first_channel)
186                }
187            }
188            _ => todo!("unsupported channel count"),
189        }
190    }
191}
192
193/// constant source, only works on a single span
194pub struct TakeSamples<S> {
195    inner: S,
196    left_to_take: usize,
197}
198
199impl<S: Source> Iterator for TakeSamples<S> {
200    type Item = Sample;
201
202    fn next(&mut self) -> Option<Self::Item> {
203        if self.left_to_take == 0 {
204            None
205        } else {
206            self.left_to_take -= 1;
207            self.inner.next()
208        }
209    }
210
211    fn size_hint(&self) -> (usize, Option<usize>) {
212        (0, Some(self.left_to_take))
213    }
214}
215
216impl<S: Source> Source for TakeSamples<S> {
217    fn current_span_len(&self) -> Option<usize> {
218        None // does not support spans
219    }
220
221    fn channels(&self) -> ChannelCount {
222        self.inner.channels()
223    }
224
225    fn sample_rate(&self) -> SampleRate {
226        self.inner.sample_rate()
227    }
228
229    fn total_duration(&self) -> Option<Duration> {
230        Some(Duration::from_secs_f64(
231            self.left_to_take as f64
232                / self.sample_rate().get() as f64
233                / self.channels().get() as f64,
234        ))
235    }
236}
237
238/// constant source, only works on a single span
239#[derive(Debug)]
240struct ReplayQueue {
241    inner: ArrayQueue<Vec<Sample>>,
242    normal_chunk_len: usize,
243    /// The last chunk in the queue may be smaller than
244    /// the normal chunk size. This is always equal to the
245    /// size of the last element in the queue.
246    /// (so normally chunk_size)
247    last_chunk: Mutex<Vec<Sample>>,
248}
249
250impl ReplayQueue {
251    fn new(queue_len: usize, chunk_size: usize) -> Self {
252        Self {
253            inner: ArrayQueue::new(queue_len),
254            normal_chunk_len: chunk_size,
255            last_chunk: Mutex::new(Vec::new()),
256        }
257    }
258    /// Returns the length in samples
259    fn len(&self) -> usize {
260        self.inner.len().saturating_sub(1) * self.normal_chunk_len
261            + self
262                .last_chunk
263                .lock()
264                .expect("Self::push_last can not poison this lock")
265                .len()
266    }
267
268    fn pop(&self) -> Option<Vec<Sample>> {
269        self.inner.pop() // removes element that was inserted first
270    }
271
272    fn push_last(&self, mut samples: Vec<Sample>) {
273        let mut last_chunk = self
274            .last_chunk
275            .lock()
276            .expect("Self::len can not poison this lock");
277        std::mem::swap(&mut *last_chunk, &mut samples);
278    }
279
280    fn push_normal(&self, samples: Vec<Sample>) {
281        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
282    }
283}
284
285/// constant source, only works on a single span
286pub struct ProcessBuffer<const N: usize, S, F>
287where
288    S: Source + Sized,
289    F: FnMut(&mut [Sample; N]),
290{
291    inner: S,
292    callback: F,
293    /// Buffer used for both input and output.
294    buffer: [Sample; N],
295    /// Next already processed sample is at this index
296    /// in buffer.
297    ///
298    /// If this is equal to the length of the buffer we have no more samples and
299    /// we must get new ones and process them
300    next: usize,
301}
302
303impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
304where
305    S: Source + Sized,
306    F: FnMut(&mut [Sample; N]),
307{
308    type Item = Sample;
309
310    fn next(&mut self) -> Option<Self::Item> {
311        self.next += 1;
312        if self.next < self.buffer.len() {
313            let sample = self.buffer[self.next];
314            return Some(sample);
315        }
316
317        for sample in &mut self.buffer {
318            *sample = self.inner.next()?
319        }
320        (self.callback)(&mut self.buffer);
321
322        self.next = 0;
323        Some(self.buffer[0])
324    }
325
326    fn size_hint(&self) -> (usize, Option<usize>) {
327        self.inner.size_hint()
328    }
329}
330
331impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
332where
333    S: Source + Sized,
334    F: FnMut(&mut [Sample; N]),
335{
336    fn current_span_len(&self) -> Option<usize> {
337        None
338    }
339
340    fn channels(&self) -> rodio::ChannelCount {
341        self.inner.channels()
342    }
343
344    fn sample_rate(&self) -> rodio::SampleRate {
345        self.inner.sample_rate()
346    }
347
348    fn total_duration(&self) -> Option<std::time::Duration> {
349        self.inner.total_duration()
350    }
351}
352
353/// constant source, only works on a single span
354pub struct InspectBuffer<const N: usize, S, F>
355where
356    S: Source + Sized,
357    F: FnMut(&[Sample; N]),
358{
359    inner: S,
360    callback: F,
361    /// Stores already emitted samples, once its full we call the callback.
362    buffer: [Sample; N],
363    /// Next free element in buffer. If this is equal to the buffer length
364    /// we have no more free lements.
365    free: usize,
366}
367
368impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
369where
370    S: Source + Sized,
371    F: FnMut(&[Sample; N]),
372{
373    type Item = Sample;
374
375    fn next(&mut self) -> Option<Self::Item> {
376        let Some(sample) = self.inner.next() else {
377            return None;
378        };
379
380        self.buffer[self.free] = sample;
381        self.free += 1;
382
383        if self.free == self.buffer.len() {
384            (self.callback)(&self.buffer);
385            self.free = 0
386        }
387
388        Some(sample)
389    }
390
391    fn size_hint(&self) -> (usize, Option<usize>) {
392        self.inner.size_hint()
393    }
394}
395
396impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
397where
398    S: Source + Sized,
399    F: FnMut(&[Sample; N]),
400{
401    fn current_span_len(&self) -> Option<usize> {
402        None
403    }
404
405    fn channels(&self) -> rodio::ChannelCount {
406        self.inner.channels()
407    }
408
409    fn sample_rate(&self) -> rodio::SampleRate {
410        self.inner.sample_rate()
411    }
412
413    fn total_duration(&self) -> Option<std::time::Duration> {
414        self.inner.total_duration()
415    }
416}
417
418/// constant source, only works on a single span
419#[derive(Debug)]
420pub struct Replayable<S: Source> {
421    inner: S,
422    buffer: Vec<Sample>,
423    chunk_size: usize,
424    tx: Arc<ReplayQueue>,
425    is_active: Arc<AtomicBool>,
426}
427
428impl<S: Source> Iterator for Replayable<S> {
429    type Item = Sample;
430
431    fn next(&mut self) -> Option<Self::Item> {
432        if let Some(sample) = self.inner.next() {
433            self.buffer.push(sample);
434            // If the buffer is full send it
435            if self.buffer.len() == self.chunk_size {
436                self.tx.push_normal(std::mem::take(&mut self.buffer));
437            }
438            Some(sample)
439        } else {
440            let last_chunk = std::mem::take(&mut self.buffer);
441            self.tx.push_last(last_chunk);
442            self.is_active.store(false, Ordering::Relaxed);
443            None
444        }
445    }
446
447    fn size_hint(&self) -> (usize, Option<usize>) {
448        self.inner.size_hint()
449    }
450}
451
452impl<S: Source> Source for Replayable<S> {
453    fn current_span_len(&self) -> Option<usize> {
454        self.inner.current_span_len()
455    }
456
457    fn channels(&self) -> ChannelCount {
458        self.inner.channels()
459    }
460
461    fn sample_rate(&self) -> SampleRate {
462        self.inner.sample_rate()
463    }
464
465    fn total_duration(&self) -> Option<Duration> {
466        self.inner.total_duration()
467    }
468}
469
470/// constant source, only works on a single span
471#[derive(Debug)]
472pub struct Replay {
473    rx: Arc<ReplayQueue>,
474    buffer: std::vec::IntoIter<Sample>,
475    sleep_duration: Duration,
476    sample_rate: SampleRate,
477    channel_count: ChannelCount,
478    source_is_active: Arc<AtomicBool>,
479}
480
481impl Replay {
482    pub fn source_is_active(&self) -> bool {
483        // - source could return None and not drop
484        // - source could be dropped before returning None
485        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
486    }
487
488    /// Duration of what is in the buffer and can be returned without blocking.
489    pub fn duration_ready(&self) -> Duration {
490        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
491
492        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
493        Duration::from_secs_f64(seconds_queued)
494    }
495
496    /// Number of samples in the buffer and can be returned without blocking.
497    pub fn samples_ready(&self) -> usize {
498        self.rx.len() + self.buffer.len()
499    }
500}
501
502impl Iterator for Replay {
503    type Item = Sample;
504
505    fn next(&mut self) -> Option<Self::Item> {
506        if let Some(sample) = self.buffer.next() {
507            return Some(sample);
508        }
509
510        loop {
511            if let Some(new_buffer) = self.rx.pop() {
512                self.buffer = new_buffer.into_iter();
513                return self.buffer.next();
514            }
515
516            if !self.source_is_active() {
517                return None;
518            }
519
520            // The queue does not support blocking on a next item. We want this queue as it
521            // is quite fast and provides a fixed size. We know how many samples are in a
522            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
523            std::thread::sleep(self.sleep_duration);
524        }
525    }
526
527    fn size_hint(&self) -> (usize, Option<usize>) {
528        ((self.rx.len() + self.buffer.len()), None)
529    }
530}
531
532impl Source for Replay {
533    fn current_span_len(&self) -> Option<usize> {
534        None // source is not compatible with spans
535    }
536
537    fn channels(&self) -> ChannelCount {
538        self.channel_count
539    }
540
541    fn sample_rate(&self) -> SampleRate {
542        self.sample_rate
543    }
544
545    fn total_duration(&self) -> Option<Duration> {
546        None
547    }
548}
549
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Five distinct samples; in the 1 Hz mono test source each lasts one second.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Mono, 1 Hz source over [`SAMPLES`].
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let processed = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|block| assert_eq!(*block, SAMPLES));
            let _ = processed.count();
        }

        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.iter().map(|sample| sample + 1.0).collect();

            let yielded: Vec<Sample> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|block| {
                    block.iter_mut().for_each(|sample| *sample += 1.0)
                })
                .collect();

            assert_eq!(yielded, expected);
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            // Five samples with blocks of three: the trailing two never form a block.
            let n_yielded = test_source()
                .process_buffer::<3, _>(|block| assert_eq!(block, &SAMPLES[..3]))
                .count();
            assert_eq!(n_yielded, 3);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let inspected = test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|block| assert_eq!(*block, SAMPLES));
            let _ = inspected.count();
        }

        #[test]
        fn source_does_not_truncate() {
            // Only the first block of three is complete, yet all samples pass through.
            let n_yielded = test_source()
                .inspect_buffer::<3, _>(|block| assert_eq!(block, &SAMPLES[..3]))
                .count();
            assert_eq!(n_yielded, SAMPLES.len());
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let history: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&history, &SAMPLES[0..3],);

            // Finish the source; the replay must follow it to the end.
            source.count();
            let live: Vec<Sample> = replay.collect();
            assert_eq!(&live, &SAMPLES[3..5],);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Pull every sample without letting the source report its end.
            source.by_ref().take(5).count();
            let latest: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&latest, &SAMPLES[3..5]);

            // Now let the source finish; nothing is left to replay.
            source.count();
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Consume every sample but do not let the source report its end.
            source.by_ref().take(40_000).count();

            // Take everything that is available without blocking.
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let two_seconds = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let one_chunk = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= two_seconds - one_chunk);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).count(); // half a second
            let one_chunk = 16_000 / 10; // 100ms
            assert!(replay.samples_ready() >= 8000 - one_chunk);
        }
    }
}