rodio_ext.rs

  1use std::{
  2    f32,
  3    num::NonZero,
  4    sync::{
  5        Arc, Mutex,
  6        atomic::{AtomicBool, Ordering},
  7    },
  8    time::Duration,
  9};
 10
 11use crossbeam::queue::ArrayQueue;
 12use denoise::{Denoiser, DenoiserError};
 13use log::warn;
 14use rodio::{
 15    ChannelCount, Sample, SampleRate, Source, conversions::SampleRateConverter, nz,
 16    source::UniformSourceIterator,
 17};
 18
 19const MAX_CHANNELS: usize = 8;
 20
 21#[derive(Debug, thiserror::Error)]
 22#[error("Replay duration is too short must be >= 100ms")]
 23pub struct ReplayDurationTooShort;
 24
 25// These all require constant sources (so the span is infinitely long)
 26// this is not guaranteed by rodio however we know it to be true in all our
 27// applications. Rodio desperately needs a constant source concept.
 28pub trait RodioExt: Source + Sized {
 29    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 30    where
 31        F: FnMut(&mut [Sample; N]);
 32    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 33    where
 34        F: FnMut(&[Sample; N]);
 35    fn replayable(
 36        self,
 37        duration: Duration,
 38    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 39    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 40    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
 41    fn constant_params(
 42        self,
 43        channel_count: ChannelCount,
 44        sample_rate: SampleRate,
 45    ) -> UniformSourceIterator<Self>;
 46    fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self>;
 47    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
 48}
 49
 50impl<S: Source> RodioExt for S {
 51    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 52    where
 53        F: FnMut(&mut [Sample; N]),
 54    {
 55        ProcessBuffer {
 56            inner: self,
 57            callback,
 58            buffer: [0.0; N],
 59            next: N,
 60        }
 61    }
 62    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 63    where
 64        F: FnMut(&[Sample; N]),
 65    {
 66        InspectBuffer {
 67            inner: self,
 68            callback,
 69            buffer: [0.0; N],
 70            free: 0,
 71        }
 72    }
 73    /// Maintains a live replay with a history of at least `duration` seconds.
 74    ///
 75    /// Note:
 76    /// History can be 100ms longer if the source drops before or while the
 77    /// replay is being read
 78    ///
 79    /// # Errors
 80    /// If duration is smaller than 100ms
 81    fn replayable(
 82        self,
 83        duration: Duration,
 84    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 85        if duration < Duration::from_millis(100) {
 86            return Err(ReplayDurationTooShort);
 87        }
 88
 89        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 90        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 91        let samples_to_queue =
 92            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 93
 94        let chunk_size =
 95            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 96        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 97
 98        let is_active = Arc::new(AtomicBool::new(true));
 99        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
100        Ok((
101            Replay {
102                rx: Arc::clone(&queue),
103                buffer: Vec::new().into_iter(),
104                sleep_duration: duration / 2,
105                sample_rate: self.sample_rate(),
106                channel_count: self.channels(),
107                source_is_active: is_active.clone(),
108            },
109            Replayable {
110                tx: queue,
111                inner: self,
112                buffer: Vec::with_capacity(chunk_size),
113                chunk_size,
114                is_active,
115            },
116        ))
117    }
118    fn take_samples(self, n: usize) -> TakeSamples<S> {
119        TakeSamples {
120            inner: self,
121            left_to_take: n,
122        }
123    }
124    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
125        let res = Denoiser::try_new(self);
126        res
127    }
128    fn constant_params(
129        self,
130        channel_count: ChannelCount,
131        sample_rate: SampleRate,
132    ) -> UniformSourceIterator<Self> {
133        UniformSourceIterator::new(self, channel_count, sample_rate)
134    }
135    fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self> {
136        ConstantSampleRate::new(self, sample_rate)
137    }
138    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
139        ToMono::new(self)
140    }
141}
142
143pub struct ConstantSampleRate<S: Source> {
144    inner: SampleRateConverter<S>,
145    channels: ChannelCount,
146    sample_rate: SampleRate,
147}
148
149impl<S: Source> ConstantSampleRate<S> {
150    fn new(source: S, target_rate: SampleRate) -> Self {
151        let input_sample_rate = source.sample_rate();
152        let channels = source.channels();
153        let inner = SampleRateConverter::new(source, input_sample_rate, target_rate, channels);
154        Self {
155            inner,
156            channels,
157            sample_rate: target_rate,
158        }
159    }
160}
161
162impl<S: Source> Iterator for ConstantSampleRate<S> {
163    type Item = rodio::Sample;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        self.inner.next()
167    }
168
169    fn size_hint(&self) -> (usize, Option<usize>) {
170        self.inner.size_hint()
171    }
172}
173
174impl<S: Source> Source for ConstantSampleRate<S> {
175    fn current_span_len(&self) -> Option<usize> {
176        None
177    }
178
179    fn channels(&self) -> ChannelCount {
180        self.channels
181    }
182
183    fn sample_rate(&self) -> SampleRate {
184        self.sample_rate
185    }
186
187    fn total_duration(&self) -> Option<Duration> {
188        None // not supported (not used by us)
189    }
190}
191
192const TYPICAL_NOISE_FLOOR: Sample = 1e-3;
193
194/// constant source, only works on a single span
195pub struct ToMono<S> {
196    inner: S,
197    input_channel_count: ChannelCount,
198    connected_channels: ChannelCount,
199    /// running mean of second channel 'volume'
200    means: [f32; MAX_CHANNELS],
201}
202impl<S: Source> ToMono<S> {
203    fn new(input: S) -> Self {
204        let channels = input
205            .channels()
206            .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
207        if channels < input.channels() {
208            warn!("Ignoring input channels {}..", channels.get());
209        }
210
211        Self {
212            connected_channels: channels,
213            input_channel_count: channels,
214            inner: input,
215            means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
216        }
217    }
218}
219
220impl<S: Source> Source for ToMono<S> {
221    fn current_span_len(&self) -> Option<usize> {
222        None
223    }
224
225    fn channels(&self) -> ChannelCount {
226        rodio::nz!(1)
227    }
228
229    fn sample_rate(&self) -> SampleRate {
230        self.inner.sample_rate()
231    }
232
233    fn total_duration(&self) -> Option<Duration> {
234        self.inner.total_duration()
235    }
236}
237
238fn update_mean(mean: &mut f32, sample: Sample) {
239    const HISTORY: f32 = 500.0;
240    *mean *= (HISTORY - 1.0) / HISTORY;
241    *mean += sample.abs() / HISTORY;
242}
243
244impl<S: Source> Iterator for ToMono<S> {
245    type Item = Sample;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        let mut mono_sample = 0f32;
249        let mut active_channels = 0;
250        for channel in 0..self.input_channel_count.get() as usize {
251            let sample = self.inner.next()?;
252            mono_sample += sample;
253
254            update_mean(&mut self.means[channel], sample);
255            if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
256                active_channels += 1;
257            }
258        }
259        mono_sample /= self.connected_channels.get() as f32;
260        self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
261
262        Some(mono_sample)
263    }
264}
265
266/// constant source, only works on a single span
267pub struct TakeSamples<S> {
268    inner: S,
269    left_to_take: usize,
270}
271
272impl<S: Source> Iterator for TakeSamples<S> {
273    type Item = Sample;
274
275    fn next(&mut self) -> Option<Self::Item> {
276        if self.left_to_take == 0 {
277            None
278        } else {
279            self.left_to_take -= 1;
280            self.inner.next()
281        }
282    }
283
284    fn size_hint(&self) -> (usize, Option<usize>) {
285        (0, Some(self.left_to_take))
286    }
287}
288
289impl<S: Source> Source for TakeSamples<S> {
290    fn current_span_len(&self) -> Option<usize> {
291        None // does not support spans
292    }
293
294    fn channels(&self) -> ChannelCount {
295        self.inner.channels()
296    }
297
298    fn sample_rate(&self) -> SampleRate {
299        self.inner.sample_rate()
300    }
301
302    fn total_duration(&self) -> Option<Duration> {
303        Some(Duration::from_secs_f64(
304            self.left_to_take as f64
305                / self.sample_rate().get() as f64
306                / self.channels().get() as f64,
307        ))
308    }
309}
310
311/// constant source, only works on a single span
312#[derive(Debug)]
313struct ReplayQueue {
314    inner: ArrayQueue<Vec<Sample>>,
315    normal_chunk_len: usize,
316    /// The last chunk in the queue may be smaller than
317    /// the normal chunk size. This is always equal to the
318    /// size of the last element in the queue.
319    /// (so normally chunk_size)
320    last_chunk: Mutex<Vec<Sample>>,
321}
322
323impl ReplayQueue {
324    fn new(queue_len: usize, chunk_size: usize) -> Self {
325        Self {
326            inner: ArrayQueue::new(queue_len),
327            normal_chunk_len: chunk_size,
328            last_chunk: Mutex::new(Vec::new()),
329        }
330    }
331    /// Returns the length in samples
332    fn len(&self) -> usize {
333        self.inner.len().saturating_sub(1) * self.normal_chunk_len
334            + self
335                .last_chunk
336                .lock()
337                .expect("Self::push_last can not poison this lock")
338                .len()
339    }
340
341    fn pop(&self) -> Option<Vec<Sample>> {
342        self.inner.pop() // removes element that was inserted first
343    }
344
345    fn push_last(&self, mut samples: Vec<Sample>) {
346        let mut last_chunk = self
347            .last_chunk
348            .lock()
349            .expect("Self::len can not poison this lock");
350        std::mem::swap(&mut *last_chunk, &mut samples);
351    }
352
353    fn push_normal(&self, samples: Vec<Sample>) {
354        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
355    }
356}
357
358/// constant source, only works on a single span
359pub struct ProcessBuffer<const N: usize, S, F>
360where
361    S: Source + Sized,
362    F: FnMut(&mut [Sample; N]),
363{
364    inner: S,
365    callback: F,
366    /// Buffer used for both input and output.
367    buffer: [Sample; N],
368    /// Next already processed sample is at this index
369    /// in buffer.
370    ///
371    /// If this is equal to the length of the buffer we have no more samples and
372    /// we must get new ones and process them
373    next: usize,
374}
375
376impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
377where
378    S: Source + Sized,
379    F: FnMut(&mut [Sample; N]),
380{
381    type Item = Sample;
382
383    fn next(&mut self) -> Option<Self::Item> {
384        self.next += 1;
385        if self.next < self.buffer.len() {
386            let sample = self.buffer[self.next];
387            return Some(sample);
388        }
389
390        for sample in &mut self.buffer {
391            *sample = self.inner.next()?
392        }
393        (self.callback)(&mut self.buffer);
394
395        self.next = 0;
396        Some(self.buffer[0])
397    }
398
399    fn size_hint(&self) -> (usize, Option<usize>) {
400        self.inner.size_hint()
401    }
402}
403
404impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
405where
406    S: Source + Sized,
407    F: FnMut(&mut [Sample; N]),
408{
409    fn current_span_len(&self) -> Option<usize> {
410        None
411    }
412
413    fn channels(&self) -> rodio::ChannelCount {
414        self.inner.channels()
415    }
416
417    fn sample_rate(&self) -> rodio::SampleRate {
418        self.inner.sample_rate()
419    }
420
421    fn total_duration(&self) -> Option<std::time::Duration> {
422        self.inner.total_duration()
423    }
424}
425
426/// constant source, only works on a single span
427pub struct InspectBuffer<const N: usize, S, F>
428where
429    S: Source + Sized,
430    F: FnMut(&[Sample; N]),
431{
432    inner: S,
433    callback: F,
434    /// Stores already emitted samples, once its full we call the callback.
435    buffer: [Sample; N],
436    /// Next free element in buffer. If this is equal to the buffer length
437    /// we have no more free lements.
438    free: usize,
439}
440
441impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
442where
443    S: Source + Sized,
444    F: FnMut(&[Sample; N]),
445{
446    type Item = Sample;
447
448    fn next(&mut self) -> Option<Self::Item> {
449        let Some(sample) = self.inner.next() else {
450            return None;
451        };
452
453        self.buffer[self.free] = sample;
454        self.free += 1;
455
456        if self.free == self.buffer.len() {
457            (self.callback)(&self.buffer);
458            self.free = 0
459        }
460
461        Some(sample)
462    }
463
464    fn size_hint(&self) -> (usize, Option<usize>) {
465        self.inner.size_hint()
466    }
467}
468
469impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
470where
471    S: Source + Sized,
472    F: FnMut(&[Sample; N]),
473{
474    fn current_span_len(&self) -> Option<usize> {
475        None
476    }
477
478    fn channels(&self) -> rodio::ChannelCount {
479        self.inner.channels()
480    }
481
482    fn sample_rate(&self) -> rodio::SampleRate {
483        self.inner.sample_rate()
484    }
485
486    fn total_duration(&self) -> Option<std::time::Duration> {
487        self.inner.total_duration()
488    }
489}
490
491/// constant source, only works on a single span
492#[derive(Debug)]
493pub struct Replayable<S: Source> {
494    inner: S,
495    buffer: Vec<Sample>,
496    chunk_size: usize,
497    tx: Arc<ReplayQueue>,
498    is_active: Arc<AtomicBool>,
499}
500
501impl<S: Source> Iterator for Replayable<S> {
502    type Item = Sample;
503
504    fn next(&mut self) -> Option<Self::Item> {
505        if let Some(sample) = self.inner.next() {
506            self.buffer.push(sample);
507            // If the buffer is full send it
508            if self.buffer.len() == self.chunk_size {
509                self.tx.push_normal(std::mem::take(&mut self.buffer));
510            }
511            Some(sample)
512        } else {
513            let last_chunk = std::mem::take(&mut self.buffer);
514            self.tx.push_last(last_chunk);
515            self.is_active.store(false, Ordering::Relaxed);
516            None
517        }
518    }
519
520    fn size_hint(&self) -> (usize, Option<usize>) {
521        self.inner.size_hint()
522    }
523}
524
525impl<S: Source> Source for Replayable<S> {
526    fn current_span_len(&self) -> Option<usize> {
527        self.inner.current_span_len()
528    }
529
530    fn channels(&self) -> ChannelCount {
531        self.inner.channels()
532    }
533
534    fn sample_rate(&self) -> SampleRate {
535        self.inner.sample_rate()
536    }
537
538    fn total_duration(&self) -> Option<Duration> {
539        self.inner.total_duration()
540    }
541}
542
543/// constant source, only works on a single span
544#[derive(Debug)]
545pub struct Replay {
546    rx: Arc<ReplayQueue>,
547    buffer: std::vec::IntoIter<Sample>,
548    sleep_duration: Duration,
549    sample_rate: SampleRate,
550    channel_count: ChannelCount,
551    source_is_active: Arc<AtomicBool>,
552}
553
554impl Replay {
555    pub fn source_is_active(&self) -> bool {
556        // - source could return None and not drop
557        // - source could be dropped before returning None
558        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
559    }
560
561    /// Duration of what is in the buffer and can be returned without blocking.
562    pub fn duration_ready(&self) -> Duration {
563        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
564
565        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
566        Duration::from_secs_f64(seconds_queued)
567    }
568
569    /// Number of samples in the buffer and can be returned without blocking.
570    pub fn samples_ready(&self) -> usize {
571        self.rx.len() + self.buffer.len()
572    }
573}
574
575impl Iterator for Replay {
576    type Item = Sample;
577
578    fn next(&mut self) -> Option<Self::Item> {
579        if let Some(sample) = self.buffer.next() {
580            return Some(sample);
581        }
582
583        loop {
584            if let Some(new_buffer) = self.rx.pop() {
585                self.buffer = new_buffer.into_iter();
586                return self.buffer.next();
587            }
588
589            if !self.source_is_active() {
590                return None;
591            }
592
593            // The queue does not support blocking on a next item. We want this queue as it
594            // is quite fast and provides a fixed size. We know how many samples are in a
595            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
596            std::thread::sleep(self.sleep_duration);
597        }
598    }
599
600    fn size_hint(&self) -> (usize, Option<usize>) {
601        ((self.rx.len() + self.buffer.len()), None)
602    }
603}
604
605impl Source for Replay {
606    fn current_span_len(&self) -> Option<usize> {
607        None // source is not compatible with spans
608    }
609
610    fn channels(&self) -> ChannelCount {
611        self.channel_count
612    }
613
614    fn sample_rate(&self) -> SampleRate {
615        self.sample_rate
616    }
617
618    fn total_duration(&self) -> Option<Duration> {
619        None
620    }
621}
622
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    // Distinct values so that ordering and truncation mistakes show up.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Mono source at 1 sample per second built from `SAMPLES`, so one sample
    /// corresponds to one second of audio.
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        // The block is as long as the whole input, so the callback sees every
        // sample in a single call.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Whatever the callback writes into the block is what gets yielded.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        // Five samples hold only one full block of three; the trailing two are
        // dropped.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        // The block is as long as the whole input, so the callback sees every
        // sample in a single call.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Only one full block of three reaches the callback, but all five samples
        // are still passed through.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        // At 1 sample/s `replayable` uses 1-sample chunks, so a 3s history is a
        // ring of three chunks: the first three samples are all still queued.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            // Running the source to its end lets the replay drain the rest and
            // then finish.
            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        // A 2s history at 1 sample/s is a ring of two 1-sample chunks, so only
        // the newest two samples survive five pushes.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        // 16 kHz mono: chunks are 1600 samples (100ms) and a 2s history is 20
        // chunks. `margin` tolerates reporting up to one chunk less than the
        // full 2s.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        // Nothing is ready before the source produces anything. After half a
        // second, at least that much minus one chunk must be reported.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            // `take` consumes the `Replayable`; it is dropped once `count` returns.
            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}