rodio_ext.rs

  1use std::{
  2    f32,
  3    sync::{
  4        Arc, Mutex,
  5        atomic::{AtomicBool, Ordering},
  6    },
  7    time::Duration,
  8};
  9
 10use crossbeam::queue::ArrayQueue;
 11use denoise::{Denoiser, DenoiserError};
 12use rodio::{ChannelCount, Sample, SampleRate, Source, source::UniformSourceIterator};
 13
 14#[derive(Debug, thiserror::Error)]
 15#[error("Replay duration is too short must be >= 100ms")]
 16pub struct ReplayDurationTooShort;
 17
 18pub trait RodioExt: Source + Sized {
 19    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 20    where
 21        F: FnMut(&mut [Sample; N]);
 22    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 23    where
 24        F: FnMut(&[Sample; N]);
 25    fn replayable(
 26        self,
 27        duration: Duration,
 28    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
 29    fn take_samples(self, n: usize) -> TakeSamples<Self>;
 30    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
 31    fn constant_params(
 32        self,
 33        channel_count: ChannelCount,
 34        sample_rate: SampleRate,
 35    ) -> UniformSourceIterator<Self>;
 36    fn suspicious_stereo_to_mono(self) -> ToMono<Self>;
 37}
 38
 39impl<S: Source> RodioExt for S {
 40    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
 41    where
 42        F: FnMut(&mut [Sample; N]),
 43    {
 44        ProcessBuffer {
 45            inner: self,
 46            callback,
 47            buffer: [0.0; N],
 48            next: N,
 49        }
 50    }
 51    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
 52    where
 53        F: FnMut(&[Sample; N]),
 54    {
 55        InspectBuffer {
 56            inner: self,
 57            callback,
 58            buffer: [0.0; N],
 59            free: 0,
 60        }
 61    }
 62    /// Maintains a live replay with a history of at least `duration` seconds.
 63    ///
 64    /// Note:
 65    /// History can be 100ms longer if the source drops before or while the
 66    /// replay is being read
 67    ///
 68    /// # Errors
 69    /// If duration is smaller then 100ms
 70    fn replayable(
 71        self,
 72        duration: Duration,
 73    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
 74        if duration < Duration::from_millis(100) {
 75            return Err(ReplayDurationTooShort);
 76        }
 77
 78        let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
 79        let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
 80        let samples_to_queue =
 81            (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
 82
 83        let chunk_size =
 84            (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
 85        let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
 86
 87        let is_active = Arc::new(AtomicBool::new(true));
 88        let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
 89        Ok((
 90            Replay {
 91                rx: Arc::clone(&queue),
 92                buffer: Vec::new().into_iter(),
 93                sleep_duration: duration / 2,
 94                sample_rate: self.sample_rate(),
 95                channel_count: self.channels(),
 96                source_is_active: is_active.clone(),
 97            },
 98            Replayable {
 99                tx: queue,
100                inner: self,
101                buffer: Vec::with_capacity(chunk_size),
102                chunk_size,
103                is_active,
104            },
105        ))
106    }
107    fn take_samples(self, n: usize) -> TakeSamples<S> {
108        TakeSamples {
109            inner: self,
110            left_to_take: n,
111        }
112    }
113    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
114        let res = Denoiser::try_new(self);
115        log::info!("result of new: {res:?}");
116        res
117    }
118    fn constant_params(
119        self,
120        channel_count: ChannelCount,
121        sample_rate: SampleRate,
122    ) -> UniformSourceIterator<Self> {
123        UniformSourceIterator::new(self, channel_count, sample_rate)
124    }
125    fn suspicious_stereo_to_mono(self) -> ToMono<Self> {
126        ToMono {
127            input_channel_count: self.channels(),
128            inner: self,
129            mean: f32::EPSILON * 3.0,
130        }
131    }
132}
133
134/// constant source, only works on a single span
135pub struct ToMono<S> {
136    inner: S,
137    input_channel_count: ChannelCount,
138    /// running mean of second channel 'volume'
139    mean: f32,
140}
141impl<S> ToMono<S> {
142    fn real_stereo(&self) -> bool {
143        dbg!(self.mean);
144        dbg!(self.mean >= f32::EPSILON * 3.0)
145    }
146
147    fn update_mean(&mut self, second_channel: Sample) {
148        const HISTORY: f32 = 500.0;
149        self.mean *= (HISTORY - 1.0) / HISTORY;
150        self.mean += second_channel.abs() / HISTORY;
151    }
152}
153
154impl<S: Source> Source for ToMono<S> {
155    fn current_span_len(&self) -> Option<usize> {
156        None
157    }
158
159    fn channels(&self) -> ChannelCount {
160        rodio::nz!(1)
161    }
162
163    fn sample_rate(&self) -> SampleRate {
164        self.inner.sample_rate()
165    }
166
167    fn total_duration(&self) -> Option<Duration> {
168        self.inner.total_duration()
169    }
170}
171
172impl<S: Source> Iterator for ToMono<S> {
173    type Item = Sample;
174
175    fn next(&mut self) -> Option<Self::Item> {
176        match self.input_channel_count.get() {
177            1 => self.next(),
178            2 => {
179                let first_channel = self.next()?;
180                let second_channel = self.next()?;
181                self.update_mean(second_channel);
182
183                if self.real_stereo() {
184                    Some((first_channel + second_channel) / 2.0)
185                } else {
186                    Some(first_channel)
187                }
188            }
189            _ => todo!("unsupported channel count"),
190        }
191    }
192}
193
/// Yields at most a fixed number of samples from the wrapped source.
///
/// Constant source, only works on a single span.
pub struct TakeSamples<S> {
    /// How many more samples may still be yielded.
    left_to_take: usize,
    inner: S,
}
199
200impl<S: Source> Iterator for TakeSamples<S> {
201    type Item = Sample;
202
203    fn next(&mut self) -> Option<Self::Item> {
204        if self.left_to_take == 0 {
205            None
206        } else {
207            self.left_to_take -= 1;
208            self.inner.next()
209        }
210    }
211
212    fn size_hint(&self) -> (usize, Option<usize>) {
213        (0, Some(self.left_to_take))
214    }
215}
216
217impl<S: Source> Source for TakeSamples<S> {
218    fn current_span_len(&self) -> Option<usize> {
219        None // does not support spans
220    }
221
222    fn channels(&self) -> ChannelCount {
223        self.inner.channels()
224    }
225
226    fn sample_rate(&self) -> SampleRate {
227        self.inner.sample_rate()
228    }
229
230    fn total_duration(&self) -> Option<Duration> {
231        Some(Duration::from_secs_f64(
232            self.left_to_take as f64
233                / self.sample_rate().get() as f64
234                / self.channels().get() as f64,
235        ))
236    }
237}
238
239/// constant source, only works on a single span
240#[derive(Debug)]
241struct ReplayQueue {
242    inner: ArrayQueue<Vec<Sample>>,
243    normal_chunk_len: usize,
244    /// The last chunk in the queue may be smaller then
245    /// the normal chunk size. This is always equal to the
246    /// size of the last element in the queue.
247    /// (so normally chunk_size)
248    last_chunk: Mutex<Vec<Sample>>,
249}
250
251impl ReplayQueue {
252    fn new(queue_len: usize, chunk_size: usize) -> Self {
253        Self {
254            inner: ArrayQueue::new(queue_len),
255            normal_chunk_len: chunk_size,
256            last_chunk: Mutex::new(Vec::new()),
257        }
258    }
259    /// Returns the length in samples
260    fn len(&self) -> usize {
261        self.inner.len().saturating_sub(1) * self.normal_chunk_len
262            + self
263                .last_chunk
264                .lock()
265                .expect("Self::push_last can not poison this lock")
266                .len()
267    }
268
269    fn pop(&self) -> Option<Vec<Sample>> {
270        self.inner.pop() // removes element that was inserted first
271    }
272
273    fn push_last(&self, mut samples: Vec<Sample>) {
274        let mut last_chunk = self
275            .last_chunk
276            .lock()
277            .expect("Self::len can not poison this lock");
278        std::mem::swap(&mut *last_chunk, &mut samples);
279    }
280
281    fn push_normal(&self, samples: Vec<Sample>) {
282        let _pushed_out_of_ringbuf = self.inner.force_push(samples);
283    }
284}
285
286/// constant source, only works on a single span
287pub struct ProcessBuffer<const N: usize, S, F>
288where
289    S: Source + Sized,
290    F: FnMut(&mut [Sample; N]),
291{
292    inner: S,
293    callback: F,
294    /// Buffer used for both input and output.
295    buffer: [Sample; N],
296    /// Next already processed sample is at this index
297    /// in buffer.
298    ///
299    /// If this is equal to the length of the buffer we have no more samples and
300    /// we must get new ones and process them
301    next: usize,
302}
303
304impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
305where
306    S: Source + Sized,
307    F: FnMut(&mut [Sample; N]),
308{
309    type Item = Sample;
310
311    fn next(&mut self) -> Option<Self::Item> {
312        self.next += 1;
313        if self.next < self.buffer.len() {
314            let sample = self.buffer[self.next];
315            return Some(sample);
316        }
317
318        for sample in &mut self.buffer {
319            *sample = self.inner.next()?
320        }
321        (self.callback)(&mut self.buffer);
322
323        self.next = 0;
324        Some(self.buffer[0])
325    }
326
327    fn size_hint(&self) -> (usize, Option<usize>) {
328        self.inner.size_hint()
329    }
330}
331
332impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
333where
334    S: Source + Sized,
335    F: FnMut(&mut [Sample; N]),
336{
337    fn current_span_len(&self) -> Option<usize> {
338        None
339    }
340
341    fn channels(&self) -> rodio::ChannelCount {
342        self.inner.channels()
343    }
344
345    fn sample_rate(&self) -> rodio::SampleRate {
346        self.inner.sample_rate()
347    }
348
349    fn total_duration(&self) -> Option<std::time::Duration> {
350        self.inner.total_duration()
351    }
352}
353
354/// constant source, only works on a single span
355pub struct InspectBuffer<const N: usize, S, F>
356where
357    S: Source + Sized,
358    F: FnMut(&[Sample; N]),
359{
360    inner: S,
361    callback: F,
362    /// Stores already emitted samples, once its full we call the callback.
363    buffer: [Sample; N],
364    /// Next free element in buffer. If this is equal to the buffer length
365    /// we have no more free lements.
366    free: usize,
367}
368
369impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
370where
371    S: Source + Sized,
372    F: FnMut(&[Sample; N]),
373{
374    type Item = Sample;
375
376    fn next(&mut self) -> Option<Self::Item> {
377        let Some(sample) = self.inner.next() else {
378            return None;
379        };
380
381        self.buffer[self.free] = sample;
382        self.free += 1;
383
384        if self.free == self.buffer.len() {
385            (self.callback)(&self.buffer);
386            self.free = 0
387        }
388
389        Some(sample)
390    }
391
392    fn size_hint(&self) -> (usize, Option<usize>) {
393        self.inner.size_hint()
394    }
395}
396
397impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
398where
399    S: Source + Sized,
400    F: FnMut(&[Sample; N]),
401{
402    fn current_span_len(&self) -> Option<usize> {
403        None
404    }
405
406    fn channels(&self) -> rodio::ChannelCount {
407        self.inner.channels()
408    }
409
410    fn sample_rate(&self) -> rodio::SampleRate {
411        self.inner.sample_rate()
412    }
413
414    fn total_duration(&self) -> Option<std::time::Duration> {
415        self.inner.total_duration()
416    }
417}
418
419/// constant source, only works on a single span
420#[derive(Debug)]
421pub struct Replayable<S: Source> {
422    inner: S,
423    buffer: Vec<Sample>,
424    chunk_size: usize,
425    tx: Arc<ReplayQueue>,
426    is_active: Arc<AtomicBool>,
427}
428
429impl<S: Source> Iterator for Replayable<S> {
430    type Item = Sample;
431
432    fn next(&mut self) -> Option<Self::Item> {
433        if let Some(sample) = self.inner.next() {
434            self.buffer.push(sample);
435            // If the buffer is full send it
436            if self.buffer.len() == self.chunk_size {
437                self.tx.push_normal(std::mem::take(&mut self.buffer));
438            }
439            Some(sample)
440        } else {
441            let last_chunk = std::mem::take(&mut self.buffer);
442            self.tx.push_last(last_chunk);
443            self.is_active.store(false, Ordering::Relaxed);
444            None
445        }
446    }
447
448    fn size_hint(&self) -> (usize, Option<usize>) {
449        self.inner.size_hint()
450    }
451}
452
453impl<S: Source> Source for Replayable<S> {
454    fn current_span_len(&self) -> Option<usize> {
455        self.inner.current_span_len()
456    }
457
458    fn channels(&self) -> ChannelCount {
459        self.inner.channels()
460    }
461
462    fn sample_rate(&self) -> SampleRate {
463        self.inner.sample_rate()
464    }
465
466    fn total_duration(&self) -> Option<Duration> {
467        self.inner.total_duration()
468    }
469}
470
471/// constant source, only works on a single span
472#[derive(Debug)]
473pub struct Replay {
474    rx: Arc<ReplayQueue>,
475    buffer: std::vec::IntoIter<Sample>,
476    sleep_duration: Duration,
477    sample_rate: SampleRate,
478    channel_count: ChannelCount,
479    source_is_active: Arc<AtomicBool>,
480}
481
482impl Replay {
483    pub fn source_is_active(&self) -> bool {
484        // - source could return None and not drop
485        // - source could be dropped before returning None
486        self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
487    }
488
489    /// Duration of what is in the buffer and can be returned without blocking.
490    pub fn duration_ready(&self) -> Duration {
491        let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
492
493        let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
494        Duration::from_secs_f64(seconds_queued)
495    }
496
497    /// Number of samples in the buffer and can be returned without blocking.
498    pub fn samples_ready(&self) -> usize {
499        self.rx.len() + self.buffer.len()
500    }
501}
502
503impl Iterator for Replay {
504    type Item = Sample;
505
506    fn next(&mut self) -> Option<Self::Item> {
507        if let Some(sample) = self.buffer.next() {
508            return Some(sample);
509        }
510
511        loop {
512            if let Some(new_buffer) = self.rx.pop() {
513                self.buffer = new_buffer.into_iter();
514                return self.buffer.next();
515            }
516
517            if !self.source_is_active() {
518                return None;
519            }
520
521            // The queue does not support blocking on a next item. We want this queue as it
522            // is quite fast and provides a fixed size. We know how many samples are in a
523            // buffer so if we do not get one now we must be getting one after `sleep_duration`.
524            std::thread::sleep(self.sleep_duration);
525        }
526    }
527
528    fn size_hint(&self) -> (usize, Option<usize>) {
529        ((self.rx.len() + self.buffer.len()), None)
530    }
531}
532
533impl Source for Replay {
534    fn current_span_len(&self) -> Option<usize> {
535        None // source is not compatible with spans
536    }
537
538    fn channels(&self) -> ChannelCount {
539        self.channel_count
540    }
541
542    fn sample_rate(&self) -> SampleRate {
543        self.sample_rate
544    }
545
546    fn total_duration(&self) -> Option<Duration> {
547        None
548    }
549}
550
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Distinct values so that reordered or dropped samples are detectable.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Mono source at one sample per second yielding [`SAMPLES`].
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .for_each(drop);
        }

        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.iter().map(|sample| sample + 1.0).collect();

            let yielded: Vec<Sample> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    buffer.iter_mut().for_each(|sample| *sample += 1.0);
                })
                .collect();

            assert_eq!(yielded, expected);
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            let n_yielded = test_source()
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();

            assert_eq!(n_yielded, 3);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .for_each(drop);
        }

        #[test]
        fn source_does_not_truncate() {
            let n_yielded = test_source()
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();

            assert_eq!(n_yielded, SAMPLES.len());
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer then 100ms");

            source.by_ref().take(3).for_each(drop);
            let first_part: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&first_part, &SAMPLES[0..3]);

            // Exhausting the source ends it; the replay then drains the rest.
            source.for_each(drop);
            let rest: Vec<Sample> = replay.collect();
            assert_eq!(&rest, &SAMPLES[3..5]);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // Pull every sample without letting the source return `None` yet.
            source.by_ref().take(5).for_each(drop);
            let latest: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&latest, &SAMPLES[3..5]);

            source.for_each(drop); // let the source end
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // Consume every sample but do not let the source end yet.
            source.by_ref().take(40_000).for_each(drop);

            // Read everything that is available without blocking.
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let two_seconds = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let one_chunk = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= two_seconds - one_chunk);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).for_each(drop); // half a second
            let one_chunk = 16_000 / 10; // 100ms
            assert!(replay.samples_ready() >= 8000 - one_chunk);
        }
    }
}