rodio_ext.rs

  1use std::{
  2    num::NonZero,
  3    sync::{
  4        Arc, Mutex,
  5        atomic::{AtomicBool, Ordering},
  6    },
  7    time::Duration,
  8};
  9
 10use crossbeam::queue::ArrayQueue;
 11use denoise::{Denoiser, DenoiserError};
 12use log::warn;
 13use rodio::{
 14    ChannelCount, Sample, SampleRate, Source, conversions::SampleRateConverter, nz,
 15    source::UniformSourceIterator,
 16};
 17
/// Maximum number of input channels for which `ToMono` keeps per-channel
/// state; `ToMono::new` clamps the tracked channel count to this value.
const MAX_CHANNELS: usize = 8;
 19
/// Returned by [`RodioExt::replayable`] when the requested history is
/// shorter than the 100ms minimum.
#[derive(Debug, thiserror::Error)]
#[error("Replay duration is too short must be >= 100ms")]
pub struct ReplayDurationTooShort;
 23
 24// These all require constant sources (so the span is infinitely long)
 25// this is not guaranteed by rodio however we know it to be true in all our
 26// applications. Rodio desperately needs a constant source concept.
 27pub trait RodioExt: Source + Sized {
 28    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 29    where
 30        F: FnMut(&mut [Sample; N]);
 31    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 32    where
 33        F: FnMut(&[Sample; N]);
 34    fn replayable(
 35        self,
 36        duration: Duration,
 37    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 38    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 39    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
 40    fn constant_params(
 41        self,
 42        channel_count: ChannelCount,
 43        sample_rate: SampleRate,
 44    ) -> UniformSourceIterator<Self>;
 45    fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self>;
 46    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
 47}
 48
 49impl<S: Source> RodioExt for S {
 50    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 51    where
 52        F: FnMut(&mut [Sample; N]),
 53    {
 54        ProcessBuffer {
 55            inner: self,
 56            callback,
 57            buffer: [0.0; N],
 58            next: N,
 59        }
 60    }
 61    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 62    where
 63        F: FnMut(&[Sample; N]),
 64    {
 65        InspectBuffer {
 66            inner: self,
 67            callback,
 68            buffer: [0.0; N],
 69            free: 0,
 70        }
 71    }
 72    /// Maintains a live replay with a history of at least `duration` seconds.
 73    ///
 74    /// Note:
 75    /// History can be 100ms longer if the source drops before or while the
 76    /// replay is being read
 77    ///
 78    /// # Errors
 79    /// If duration is smaller than 100ms
 80    fn replayable(
 81        self,
 82        duration: Duration,
 83    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 84        if duration < Duration::from_millis(100) {
 85            return Err(ReplayDurationTooShort);
 86        }
 87
 88        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 89        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 90        let samples_to_queue =
 91            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 92
 93        let chunk_size =
 94            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 95        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 96
 97        let is_active = Arc::new(AtomicBool::new(true));
 98        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 99        Ok((
100            Replay {
101                rx: Arc::clone(&queue),
102                buffer: Vec::new().into_iter(),
103                sleep_duration: duration / 2,
104                sample_rate: self.sample_rate(),
105                channel_count: self.channels(),
106                source_is_active: is_active.clone(),
107            },
108            Replayable {
109                tx: queue,
110                inner: self,
111                buffer: Vec::with_capacity(chunk_size),
112                chunk_size,
113                is_active,
114            },
115        ))
116    }
117    fn take_samples(self, n: usize) -> TakeSamples<S> {
118        TakeSamples {
119            inner: self,
120            left_to_take: n,
121        }
122    }
123    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
124        let res = Denoiser::try_new(self);
125        res
126    }
127    fn constant_params(
128        self,
129        channel_count: ChannelCount,
130        sample_rate: SampleRate,
131    ) -> UniformSourceIterator<Self> {
132        UniformSourceIterator::new(self, channel_count, sample_rate)
133    }
134    fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self> {
135        ConstantSampleRate::new(self, sample_rate)
136    }
137    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
138        ToMono::new(self)
139    }
140}
141
142pub struct ConstantSampleRate<S: Source> {
143    inner: SampleRateConverter<S>,
144    channels: ChannelCount,
145    sample_rate: SampleRate,
146}
147
148impl<S: Source> ConstantSampleRate<S> {
149    fn new(source: S, target_rate: SampleRate) -> Self {
150        let input_sample_rate = source.sample_rate();
151        let channels = source.channels();
152        let inner = SampleRateConverter::new(source, input_sample_rate, target_rate, channels);
153        Self {
154            inner,
155            channels,
156            sample_rate: target_rate,
157        }
158    }
159}
160
161impl<S: Source> Iterator for ConstantSampleRate<S> {
162    type Item = rodio::Sample;
163
164    fn next(&mut self) -> Option<Self::Item> {
165        self.inner.next()
166    }
167
168    fn size_hint(&self) -> (usize, Option<usize>) {
169        self.inner.size_hint()
170    }
171}
172
173impl<S: Source> Source for ConstantSampleRate<S> {
174    fn current_span_len(&self) -> Option<usize> {
175        None
176    }
177
178    fn channels(&self) -> ChannelCount {
179        self.channels
180    }
181
182    fn sample_rate(&self) -> SampleRate {
183        self.sample_rate
184    }
185
186    fn total_duration(&self) -> Option<Duration> {
187        None // not supported (not used by us)
188    }
189}
190
/// Starting value of every per-channel running mean in [`ToMono`]. A channel
/// counts as active while its running mean stays above a tenth of this.
const TYPICAL_NOISE_FLOOR: Sample = 1e-3;

/// Down-mixes a source to mono, averaging only over the channels whose running
/// mean level marks them as active.
///
/// constant source, only works on a single span
pub struct ToMono<S> {
    inner: S,
    /// Number of input channels that are mixed, clamped to `MAX_CHANNELS`.
    input_channel_count: ChannelCount,
    /// Number of channels found active on the previous frame (at least one);
    /// the next frame's sum is divided by this.
    connected_channels: ChannelCount,
    /// running mean of each channel's absolute sample value ('volume')
    means: [f32; MAX_CHANNELS],
}
201impl<S: Source> ToMono<S> {
202    fn new(input: S) -> Self {
203        let channels = input
204            .channels()
205            .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
206        if channels < input.channels() {
207            warn!("Ignoring input channels {}..", channels.get());
208        }
209
210        Self {
211            connected_channels: channels,
212            input_channel_count: channels,
213            inner: input,
214            means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
215        }
216    }
217}
218
219impl<S: Source> Source for ToMono<S> {
220    fn current_span_len(&self) -> Option<usize> {
221        None
222    }
223
224    fn channels(&self) -> ChannelCount {
225        rodio::nz!(1)
226    }
227
228    fn sample_rate(&self) -> SampleRate {
229        self.inner.sample_rate()
230    }
231
232    fn total_duration(&self) -> Option<Duration> {
233        self.inner.total_duration()
234    }
235}
236
237fn update_mean(mean: &mut f32, sample: Sample) {
238    const HISTORY: f32 = 500.0;
239    *mean *= (HISTORY - 1.0) / HISTORY;
240    *mean += sample.abs() / HISTORY;
241}
242
243impl<S: Source> Iterator for ToMono<S> {
244    type Item = Sample;
245
246    fn next(&mut self) -> Option<Self::Item> {
247        let mut mono_sample = 0f32;
248        let mut active_channels = 0;
249        for channel in 0..self.input_channel_count.get() as usize {
250            let sample = self.inner.next()?;
251            mono_sample += sample;
252
253            update_mean(&mut self.means[channel], sample);
254            if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
255                active_channels += 1;
256            }
257        }
258        mono_sample /= self.connected_channels.get() as f32;
259        self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
260
261        Some(mono_sample)
262    }
263}
264
265/// constant source, only works on a single span
266pub struct TakeSamples<S> {
267    inner: S,
268    left_to_take: usize,
269}
270
271impl<S: Source> Iterator for TakeSamples<S> {
272    type Item = Sample;
273
274    fn next(&mut self) -> Option<Self::Item> {
275        if self.left_to_take == 0 {
276            None
277        } else {
278            self.left_to_take -= 1;
279            self.inner.next()
280        }
281    }
282
283    fn size_hint(&self) -> (usize, Option<usize>) {
284        (0, Some(self.left_to_take))
285    }
286}
287
288impl<S: Source> Source for TakeSamples<S> {
289    fn current_span_len(&self) -> Option<usize> {
290        None // does not support spans
291    }
292
293    fn channels(&self) -> ChannelCount {
294        self.inner.channels()
295    }
296
297    fn sample_rate(&self) -> SampleRate {
298        self.inner.sample_rate()
299    }
300
301    fn total_duration(&self) -> Option<Duration> {
302        Some(Duration::from_secs_f64(
303            self.left_to_take as f64
304                / self.sample_rate().get() as f64
305                / self.channels().get() as f64,
306        ))
307    }
308}
309
310/// constant source, only works on a single span
311#[derive(Debug)]
312struct ReplayQueue {
313    inner: ArrayQueue<Vec<Sample>>,
314    normal_chunk_len: usize,
315    /// The last chunk in the queue may be smaller than
316    /// the normal chunk size. This is always equal to the
317    /// size of the last element in the queue.
318    /// (so normally chunk_size)
319    last_chunk: Mutex<Vec<Sample>>,
320}
321
322impl ReplayQueue {
323    fn new(queue_len: usize, chunk_size: usize) -> Self {
324        Self {
325            inner: ArrayQueue::new(queue_len),
326            normal_chunk_len: chunk_size,
327            last_chunk: Mutex::new(Vec::new()),
328        }
329    }
330    /// Returns the length in samples
331    fn len(&self) -> usize {
332        self.inner.len().saturating_sub(1) * self.normal_chunk_len
333            + self
334                .last_chunk
335                .lock()
336                .expect("Self::push_last can not poison this lock")
337                .len()
338    }
339
340    fn pop(&self) -> Option<Vec<Sample>> {
341        self.inner.pop() // removes element that was inserted first
342    }
343
344    fn push_last(&self, mut samples: Vec<Sample>) {
345        let mut last_chunk = self
346            .last_chunk
347            .lock()
348            .expect("Self::len can not poison this lock");
349        std::mem::swap(&mut *last_chunk, &mut samples);
350    }
351
352    fn push_normal(&self, samples: Vec<Sample>) {
353        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
354    }
355}
356
357/// constant source, only works on a single span
358pub struct ProcessBuffer<const N: usize, S, F>
359where
360    S: Source + Sized,
361    F: FnMut(&mut [Sample; N]),
362{
363    inner: S,
364    callback: F,
365    /// Buffer used for both input and output.
366    buffer: [Sample; N],
367    /// Next already processed sample is at this index
368    /// in buffer.
369    ///
370    /// If this is equal to the length of the buffer we have no more samples and
371    /// we must get new ones and process them
372    next: usize,
373}
374
375impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
376where
377    S: Source + Sized,
378    F: FnMut(&mut [Sample; N]),
379{
380    type Item = Sample;
381
382    fn next(&mut self) -> Option<Self::Item> {
383        self.next += 1;
384        if self.next < self.buffer.len() {
385            let sample = self.buffer[self.next];
386            return Some(sample);
387        }
388
389        for sample in &mut self.buffer {
390            *sample = self.inner.next()?
391        }
392        (self.callback)(&mut self.buffer);
393
394        self.next = 0;
395        Some(self.buffer[0])
396    }
397
398    fn size_hint(&self) -> (usize, Option<usize>) {
399        self.inner.size_hint()
400    }
401}
402
403impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
404where
405    S: Source + Sized,
406    F: FnMut(&mut [Sample; N]),
407{
408    fn current_span_len(&self) -> Option<usize> {
409        None
410    }
411
412    fn channels(&self) -> rodio::ChannelCount {
413        self.inner.channels()
414    }
415
416    fn sample_rate(&self) -> rodio::SampleRate {
417        self.inner.sample_rate()
418    }
419
420    fn total_duration(&self) -> Option<std::time::Duration> {
421        self.inner.total_duration()
422    }
423}
424
425/// constant source, only works on a single span
426pub struct InspectBuffer<const N: usize, S, F>
427where
428    S: Source + Sized,
429    F: FnMut(&[Sample; N]),
430{
431    inner: S,
432    callback: F,
433    /// Stores already emitted samples, once its full we call the callback.
434    buffer: [Sample; N],
435    /// Next free element in buffer. If this is equal to the buffer length
436    /// we have no more free elements.
437    free: usize,
438}
439
440impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
441where
442    S: Source + Sized,
443    F: FnMut(&[Sample; N]),
444{
445    type Item = Sample;
446
447    fn next(&mut self) -> Option<Self::Item> {
448        let Some(sample) = self.inner.next() else {
449            return None;
450        };
451
452        self.buffer[self.free] = sample;
453        self.free += 1;
454
455        if self.free == self.buffer.len() {
456            (self.callback)(&self.buffer);
457            self.free = 0
458        }
459
460        Some(sample)
461    }
462
463    fn size_hint(&self) -> (usize, Option<usize>) {
464        self.inner.size_hint()
465    }
466}
467
468impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
469where
470    S: Source + Sized,
471    F: FnMut(&[Sample; N]),
472{
473    fn current_span_len(&self) -> Option<usize> {
474        None
475    }
476
477    fn channels(&self) -> rodio::ChannelCount {
478        self.inner.channels()
479    }
480
481    fn sample_rate(&self) -> rodio::SampleRate {
482        self.inner.sample_rate()
483    }
484
485    fn total_duration(&self) -> Option<std::time::Duration> {
486        self.inner.total_duration()
487    }
488}
489
490/// constant source, only works on a single span
491#[derive(Debug)]
492pub struct Replayable<S: Source> {
493    inner: S,
494    buffer: Vec<Sample>,
495    chunk_size: usize,
496    tx: Arc<ReplayQueue>,
497    is_active: Arc<AtomicBool>,
498}
499
500impl<S: Source> Iterator for Replayable<S> {
501    type Item = Sample;
502
503    fn next(&mut self) -> Option<Self::Item> {
504        if let Some(sample) = self.inner.next() {
505            self.buffer.push(sample);
506            // If the buffer is full send it
507            if self.buffer.len() == self.chunk_size {
508                self.tx.push_normal(std::mem::take(&mut self.buffer));
509            }
510            Some(sample)
511        } else {
512            let last_chunk = std::mem::take(&mut self.buffer);
513            self.tx.push_last(last_chunk);
514            self.is_active.store(false, Ordering::Relaxed);
515            None
516        }
517    }
518
519    fn size_hint(&self) -> (usize, Option<usize>) {
520        self.inner.size_hint()
521    }
522}
523
524impl<S: Source> Source for Replayable<S> {
525    fn current_span_len(&self) -> Option<usize> {
526        self.inner.current_span_len()
527    }
528
529    fn channels(&self) -> ChannelCount {
530        self.inner.channels()
531    }
532
533    fn sample_rate(&self) -> SampleRate {
534        self.inner.sample_rate()
535    }
536
537    fn total_duration(&self) -> Option<Duration> {
538        self.inner.total_duration()
539    }
540}
541
542/// constant source, only works on a single span
543#[derive(Debug)]
544pub struct Replay {
545    rx: Arc<ReplayQueue>,
546    buffer: std::vec::IntoIter<Sample>,
547    sleep_duration: Duration,
548    sample_rate: SampleRate,
549    channel_count: ChannelCount,
550    source_is_active: Arc<AtomicBool>,
551}
552
553impl Replay {
554    pub fn source_is_active(&self) -> bool {
555        // - source could return None and not drop
556        // - source could be dropped before returning None
557        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
558    }
559
560    /// Duration of what is in the buffer and can be returned without blocking.
561    pub fn duration_ready(&self) -> Duration {
562        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
563
564        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
565        Duration::from_secs_f64(seconds_queued)
566    }
567
568    /// Number of samples in the buffer and can be returned without blocking.
569    pub fn samples_ready(&self) -> usize {
570        self.rx.len() + self.buffer.len()
571    }
572}
573
574impl Iterator for Replay {
575    type Item = Sample;
576
577    fn next(&mut self) -> Option<Self::Item> {
578        if let Some(sample) = self.buffer.next() {
579            return Some(sample);
580        }
581
582        loop {
583            if let Some(new_buffer) = self.rx.pop() {
584                self.buffer = new_buffer.into_iter();
585                return self.buffer.next();
586            }
587
588            if !self.source_is_active() {
589                return None;
590            }
591
592            // The queue does not support blocking on a next item. We want this queue as it
593            // is quite fast and provides a fixed size. We know how many samples are in a
594            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
595            std::thread::sleep(self.sleep_duration);
596        }
597    }
598
599    fn size_hint(&self) -> (usize, Option<usize>) {
600        ((self.rx.len() + self.buffer.len()), None)
601    }
602}
603
604impl Source for Replay {
605    fn current_span_len(&self) -> Option<usize> {
606        None // source is not compatible with spans
607    }
608
609    fn channels(&self) -> ChannelCount {
610        self.channel_count
611    }
612
613    fn sample_rate(&self) -> SampleRate {
614        self.sample_rate
615    }
616
617    fn total_duration(&self) -> Option<Duration> {
618        None
619    }
620}
621
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    // Five samples from a source with one channel and a sample rate of one.
    // With those parameters every replay chunk
    // (samples_per_second.div_ceil(10), rounded to whole frames) is a single
    // sample.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        // A buffer as large as the source sees every sample in one call.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Modifications made by the callback are what gets yielded.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        // Five samples in buffers of three: the trailing two are dropped.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Unlike process_buffer, samples that never fill a buffer are still
        // yielded; the callback only runs for the first full buffer.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        // A 3 second history holds three one-sample chunks; after reading
        // those the replay continues with the live samples.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        // A 2 second history only keeps the two newest one-sample chunks.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        // 2.5 seconds of input into a 2 second history: at least 2 seconds
        // (minus one 100ms chunk of slack) must be readable without blocking.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        // samples_ready starts at zero and tracks what the source produced,
        // up to one 100ms chunk of slack.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}