rodio_ext.rs

  1use std::{num::NonZero, time::Duration};
  2
  3use denoise::{Denoiser, DenoiserError};
  4use log::warn;
  5use rodio::{ChannelCount, Sample, SampleRate, Source, conversions::ChannelCountConverter, nz};
  6
  7use crate::rodio_ext::resample::FixedResampler;
  8pub use replayable::{Replay, ReplayDurationTooShort, Replayable};
  9
 10mod replayable;
 11mod resample;
 12
 13const MAX_CHANNELS: usize = 8;
 14
 15// These all require constant sources (so the span is infinitely long)
 16// this is not guaranteed by rodio however we know it to be true in all our
 17// applications. Rodio desperately needs a constant source concept.
 18pub trait RodioExt: Source + Sized {
 19    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 20    where
 21        F: FnMut(&mut [Sample; N]);
 22    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 23    where
 24        F: FnMut(&[Sample; N]);
 25    fn replayable(
 26        self,
 27        duration: Duration,
 28    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 29    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 30    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
 31    fn constant_params(
 32        self,
 33        channel_count: ChannelCount,
 34        sample_rate: SampleRate,
 35    ) -> ConstantChannelCount<FixedResampler<Self>>;
 36    fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self>;
 37    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
 38}
 39
 40impl<S: Source> RodioExt for S {
 41    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 42    where
 43        F: FnMut(&mut [Sample; N]),
 44    {
 45        ProcessBuffer {
 46            inner: self,
 47            callback,
 48            buffer: [0.0; N],
 49            next: N,
 50        }
 51    }
 52    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 53    where
 54        F: FnMut(&[Sample; N]),
 55    {
 56        InspectBuffer {
 57            inner: self,
 58            callback,
 59            buffer: [0.0; N],
 60            free: 0,
 61        }
 62    }
 63    /// Maintains a live replay with a history of at least `duration` seconds.
 64    ///
 65    /// Note:
 66    /// History can be 100ms longer if the source drops before or while the
 67    /// replay is being read
 68    ///
 69    /// # Errors
 70    /// If duration is smaller than 100ms
 71    fn replayable(
 72        self,
 73        duration: Duration,
 74    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 75        replayable::replayable(self, duration)
 76    }
 77    fn take_samples(self, n: usize) -> TakeSamples<S> {
 78        TakeSamples {
 79            inner: self,
 80            left_to_take: n,
 81        }
 82    }
 83    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
 84        let res = Denoiser::try_new(self);
 85        res
 86    }
 87    fn constant_params(
 88        self,
 89        channel_count: ChannelCount,
 90        sample_rate: SampleRate,
 91    ) -> ConstantChannelCount<FixedResampler<Self>> {
 92        ConstantChannelCount::new(self.constant_samplerate(sample_rate), channel_count)
 93    }
 94    fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self> {
 95        FixedResampler::new(self, sample_rate)
 96    }
 97    fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
 98        ToMono::new(self)
 99    }
100}
101
102pub struct ConstantChannelCount<S: Source> {
103    inner: ChannelCountConverter<S>,
104    channels: ChannelCount,
105    sample_rate: SampleRate,
106}
107
108impl<S: Source> ConstantChannelCount<S> {
109    fn new(source: S, target_channels: ChannelCount) -> Self {
110        let input_channels = source.channels();
111        let sample_rate = source.sample_rate();
112        let inner = ChannelCountConverter::new(source, input_channels, target_channels);
113        Self {
114            sample_rate,
115            inner,
116            channels: target_channels,
117        }
118    }
119}
120
121impl<S: Source> Iterator for ConstantChannelCount<S> {
122    type Item = rodio::Sample;
123
124    fn next(&mut self) -> Option<Self::Item> {
125        self.inner.next()
126    }
127
128    fn size_hint(&self) -> (usize, Option<usize>) {
129        self.inner.size_hint()
130    }
131}
132
133impl<S: Source> Source for ConstantChannelCount<S> {
134    fn current_span_len(&self) -> Option<usize> {
135        None
136    }
137
138    fn channels(&self) -> ChannelCount {
139        self.channels
140    }
141
142    fn sample_rate(&self) -> SampleRate {
143        self.sample_rate
144    }
145
146    fn total_duration(&self) -> Option<Duration> {
147        None // not supported (not used by us)
148    }
149}
150
151const TYPICAL_NOISE_FLOOR: Sample = 1e-3;
152
153/// constant source, only works on a single span
154pub struct ToMono<S> {
155    inner: S,
156    input_channel_count: ChannelCount,
157    connected_channels: ChannelCount,
158    /// running mean of second channel 'volume'
159    means: [f32; MAX_CHANNELS],
160}
161impl<S: Source> ToMono<S> {
162    fn new(input: S) -> Self {
163        let channels = input
164            .channels()
165            .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
166        if channels < input.channels() {
167            warn!("Ignoring input channels {}..", channels.get());
168        }
169
170        Self {
171            connected_channels: channels,
172            input_channel_count: channels,
173            inner: input,
174            means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
175        }
176    }
177}
178
179impl<S: Source> Source for ToMono<S> {
180    fn current_span_len(&self) -> Option<usize> {
181        None
182    }
183
184    fn channels(&self) -> ChannelCount {
185        rodio::nz!(1)
186    }
187
188    fn sample_rate(&self) -> SampleRate {
189        self.inner.sample_rate()
190    }
191
192    fn total_duration(&self) -> Option<Duration> {
193        self.inner.total_duration()
194    }
195}
196
197fn update_mean(mean: &mut f32, sample: Sample) {
198    const HISTORY: f32 = 500.0;
199    *mean *= (HISTORY - 1.0) / HISTORY;
200    *mean += sample.abs() / HISTORY;
201}
202
203impl<S: Source> Iterator for ToMono<S> {
204    type Item = Sample;
205
206    fn next(&mut self) -> Option<Self::Item> {
207        let mut mono_sample = 0f32;
208        let mut active_channels = 0;
209        for channel in 0..self.input_channel_count.get() as usize {
210            let sample = self.inner.next()?;
211            mono_sample += sample;
212
213            update_mean(&mut self.means[channel], sample);
214            if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
215                active_channels += 1;
216            }
217        }
218        mono_sample /= self.connected_channels.get() as f32;
219        self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
220
221        Some(mono_sample)
222    }
223}
224
225/// constant source, only works on a single span
226pub struct TakeSamples<S> {
227    inner: S,
228    left_to_take: usize,
229}
230
231impl<S: Source> Iterator for TakeSamples<S> {
232    type Item = Sample;
233
234    fn next(&mut self) -> Option<Self::Item> {
235        if self.left_to_take == 0 {
236            None
237        } else {
238            self.left_to_take -= 1;
239            self.inner.next()
240        }
241    }
242
243    fn size_hint(&self) -> (usize, Option<usize>) {
244        (0, Some(self.left_to_take))
245    }
246}
247
248impl<S: Source> Source for TakeSamples<S> {
249    fn current_span_len(&self) -> Option<usize> {
250        None // does not support spans
251    }
252
253    fn channels(&self) -> ChannelCount {
254        self.inner.channels()
255    }
256
257    fn sample_rate(&self) -> SampleRate {
258        self.inner.sample_rate()
259    }
260
261    fn total_duration(&self) -> Option<Duration> {
262        Some(Duration::from_secs_f64(
263            self.left_to_take as f64
264                / self.sample_rate().get() as f64
265                / self.channels().get() as f64,
266        ))
267    }
268}
269
270/// constant source, only works on a single span
271pub struct ProcessBuffer<const N: usize, S, F>
272where
273    S: Source + Sized,
274    F: FnMut(&mut [Sample; N]),
275{
276    inner: S,
277    callback: F,
278    /// Buffer used for both input and output.
279    buffer: [Sample; N],
280    /// Next already processed sample is at this index
281    /// in buffer.
282    ///
283    /// If this is equal to the length of the buffer we have no more samples and
284    /// we must get new ones and process them
285    next: usize,
286}
287
288impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
289where
290    S: Source + Sized,
291    F: FnMut(&mut [Sample; N]),
292{
293    type Item = Sample;
294
295    fn next(&mut self) -> Option<Self::Item> {
296        self.next += 1;
297        if self.next < self.buffer.len() {
298            let sample = self.buffer[self.next];
299            return Some(sample);
300        }
301
302        for sample in &mut self.buffer {
303            *sample = self.inner.next()?
304        }
305        (self.callback)(&mut self.buffer);
306
307        self.next = 0;
308        Some(self.buffer[0])
309    }
310
311    fn size_hint(&self) -> (usize, Option<usize>) {
312        self.inner.size_hint()
313    }
314}
315
316impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
317where
318    S: Source + Sized,
319    F: FnMut(&mut [Sample; N]),
320{
321    fn current_span_len(&self) -> Option<usize> {
322        None
323    }
324
325    fn channels(&self) -> rodio::ChannelCount {
326        self.inner.channels()
327    }
328
329    fn sample_rate(&self) -> rodio::SampleRate {
330        self.inner.sample_rate()
331    }
332
333    fn total_duration(&self) -> Option<std::time::Duration> {
334        self.inner.total_duration()
335    }
336}
337
338/// constant source, only works on a single span
339pub struct InspectBuffer<const N: usize, S, F>
340where
341    S: Source + Sized,
342    F: FnMut(&[Sample; N]),
343{
344    inner: S,
345    callback: F,
346    /// Stores already emitted samples, once its full we call the callback.
347    buffer: [Sample; N],
348    /// Next free element in buffer. If this is equal to the buffer length
349    /// we have no more free lements.
350    free: usize,
351}
352
353impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
354where
355    S: Source + Sized,
356    F: FnMut(&[Sample; N]),
357{
358    type Item = Sample;
359
360    fn next(&mut self) -> Option<Self::Item> {
361        let Some(sample) = self.inner.next() else {
362            return None;
363        };
364
365        self.buffer[self.free] = sample;
366        self.free += 1;
367
368        if self.free == self.buffer.len() {
369            (self.callback)(&self.buffer);
370            self.free = 0
371        }
372
373        Some(sample)
374    }
375
376    fn size_hint(&self) -> (usize, Option<usize>) {
377        self.inner.size_hint()
378    }
379}
380
381impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
382where
383    S: Source + Sized,
384    F: FnMut(&[Sample; N]),
385{
386    fn current_span_len(&self) -> Option<usize> {
387        None
388    }
389
390    fn channels(&self) -> rodio::ChannelCount {
391        self.inner.channels()
392    }
393
394    fn sample_rate(&self) -> rodio::SampleRate {
395        self.inner.sample_rate()
396    }
397
398    fn total_duration(&self) -> Option<std::time::Duration> {
399        self.inner.total_duration()
400    }
401}
402
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Input samples shared by the adapter tests.
    pub const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// A one-channel source at 1 sample per second that plays [`SAMPLES`] once.
    pub fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            // Drain the adapter; the assertion lives inside the callback.
            test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .for_each(|_| {});
        }
        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.iter().map(|sample| sample + 1.0).collect();

            let yielded: Vec<Sample> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    buffer.iter_mut().for_each(|sample| *sample += 1.0)
                })
                .collect();

            assert_eq!(yielded, expected)
        }
        #[test]
        fn source_truncates_to_whole_buffers() {
            // 5 samples in chunks of 3: only the first whole chunk survives.
            let yielded = test_source()
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            // Drain the adapter; the assertion lives inside the callback.
            test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .for_each(|_| {});
        }
        #[test]
        fn source_does_not_truncate() {
            // The trailing partial chunk is still yielded, just not inspected.
            let yielded = test_source()
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }
}