1use std::{
2 f32,
3 num::NonZero,
4 sync::{
5 Arc, Mutex,
6 atomic::{AtomicBool, Ordering},
7 },
8 time::Duration,
9};
10
11use crossbeam::queue::ArrayQueue;
12use denoise::{Denoiser, DenoiserError};
13use log::warn;
14use rodio::{
15 ChannelCount, Sample, SampleRate, Source, conversions::SampleRateConverter, nz,
16 source::UniformSourceIterator,
17};
18
/// Upper bound on the number of input channels [`ToMono`] tracks; it sizes
/// the per-channel running-mean state and clamps the channel count.
const MAX_CHANNELS: usize = 8;

/// Error returned by [`RodioExt::replayable`] when the requested history
/// duration is shorter than 100ms.
#[derive(Debug, thiserror::Error)]
#[error("Replay duration is too short must be >= 100ms")]
pub struct ReplayDurationTooShort;
24
25// These all require constant sources (so the span is infinitely long)
26// this is not guaranteed by rodio however we know it to be true in all our
27// applications. Rodio desperately needs a constant source concept.
28pub trait RodioExt: Source + Sized {
29 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
30 where
31 F: FnMut(&mut [Sample; N]);
32 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
33 where
34 F: FnMut(&[Sample; N]);
35 fn replayable(
36 self,
37 duration: Duration,
38 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
39 fn take_samples(self, n: usize) -> TakeSamples<Self>;
40 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
41 fn constant_params(
42 self,
43 channel_count: ChannelCount,
44 sample_rate: SampleRate,
45 ) -> UniformSourceIterator<Self>;
46 fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self>;
47 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
48}
49
50impl<S: Source> RodioExt for S {
51 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
52 where
53 F: FnMut(&mut [Sample; N]),
54 {
55 ProcessBuffer {
56 inner: self,
57 callback,
58 buffer: [0.0; N],
59 next: N,
60 }
61 }
62 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
63 where
64 F: FnMut(&[Sample; N]),
65 {
66 InspectBuffer {
67 inner: self,
68 callback,
69 buffer: [0.0; N],
70 free: 0,
71 }
72 }
73 /// Maintains a live replay with a history of at least `duration` seconds.
74 ///
75 /// Note:
76 /// History can be 100ms longer if the source drops before or while the
77 /// replay is being read
78 ///
79 /// # Errors
80 /// If duration is smaller than 100ms
81 fn replayable(
82 self,
83 duration: Duration,
84 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
85 if duration < Duration::from_millis(100) {
86 return Err(ReplayDurationTooShort);
87 }
88
89 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
90 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
91 let samples_to_queue =
92 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
93
94 let chunk_size =
95 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
96 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
97
98 let is_active = Arc::new(AtomicBool::new(true));
99 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
100 Ok((
101 Replay {
102 rx: Arc::clone(&queue),
103 buffer: Vec::new().into_iter(),
104 sleep_duration: duration / 2,
105 sample_rate: self.sample_rate(),
106 channel_count: self.channels(),
107 source_is_active: is_active.clone(),
108 },
109 Replayable {
110 tx: queue,
111 inner: self,
112 buffer: Vec::with_capacity(chunk_size),
113 chunk_size,
114 is_active,
115 },
116 ))
117 }
118 fn take_samples(self, n: usize) -> TakeSamples<S> {
119 TakeSamples {
120 inner: self,
121 left_to_take: n,
122 }
123 }
124 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
125 let res = Denoiser::try_new(self);
126 res
127 }
128 fn constant_params(
129 self,
130 channel_count: ChannelCount,
131 sample_rate: SampleRate,
132 ) -> UniformSourceIterator<Self> {
133 UniformSourceIterator::new(self, channel_count, sample_rate)
134 }
135 fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self> {
136 ConstantSampleRate::new(self, sample_rate)
137 }
138 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
139 ToMono::new(self)
140 }
141}
142
143pub struct ConstantSampleRate<S: Source> {
144 inner: SampleRateConverter<S>,
145 channels: ChannelCount,
146 sample_rate: SampleRate,
147}
148
149impl<S: Source> ConstantSampleRate<S> {
150 fn new(source: S, target_rate: SampleRate) -> Self {
151 let input_sample_rate = source.sample_rate();
152 let channels = source.channels();
153 let inner = SampleRateConverter::new(source, input_sample_rate, target_rate, channels);
154 Self {
155 inner,
156 channels,
157 sample_rate: target_rate,
158 }
159 }
160}
161
162impl<S: Source> Iterator for ConstantSampleRate<S> {
163 type Item = rodio::Sample;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 self.inner.next()
167 }
168
169 fn size_hint(&self) -> (usize, Option<usize>) {
170 self.inner.size_hint()
171 }
172}
173
174impl<S: Source> Source for ConstantSampleRate<S> {
175 fn current_span_len(&self) -> Option<usize> {
176 None
177 }
178
179 fn channels(&self) -> ChannelCount {
180 self.channels
181 }
182
183 fn sample_rate(&self) -> SampleRate {
184 self.sample_rate
185 }
186
187 fn total_duration(&self) -> Option<Duration> {
188 None // not supported (not used by us)
189 }
190}
191
192const TYPICAL_NOISE_FLOOR: Sample = 1e-3;
193
194/// constant source, only works on a single span
195pub struct ToMono<S> {
196 inner: S,
197 input_channel_count: ChannelCount,
198 connected_channels: ChannelCount,
199 /// running mean of second channel 'volume'
200 means: [f32; MAX_CHANNELS],
201}
202impl<S: Source> ToMono<S> {
203 fn new(input: S) -> Self {
204 let channels = input
205 .channels()
206 .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
207 if channels < input.channels() {
208 warn!("Ignoring input channels {}..", channels.get());
209 }
210
211 Self {
212 connected_channels: channels,
213 input_channel_count: channels,
214 inner: input,
215 means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
216 }
217 }
218}
219
220impl<S: Source> Source for ToMono<S> {
221 fn current_span_len(&self) -> Option<usize> {
222 None
223 }
224
225 fn channels(&self) -> ChannelCount {
226 rodio::nz!(1)
227 }
228
229 fn sample_rate(&self) -> SampleRate {
230 self.inner.sample_rate()
231 }
232
233 fn total_duration(&self) -> Option<Duration> {
234 self.inner.total_duration()
235 }
236}
237
238fn update_mean(mean: &mut f32, sample: Sample) {
239 const HISTORY: f32 = 500.0;
240 *mean *= (HISTORY - 1.0) / HISTORY;
241 *mean += sample.abs() / HISTORY;
242}
243
244impl<S: Source> Iterator for ToMono<S> {
245 type Item = Sample;
246
247 fn next(&mut self) -> Option<Self::Item> {
248 let mut mono_sample = 0f32;
249 let mut active_channels = 0;
250 for channel in 0..self.input_channel_count.get() as usize {
251 let sample = self.inner.next()?;
252 mono_sample += sample;
253
254 update_mean(&mut self.means[channel], sample);
255 if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
256 active_channels += 1;
257 }
258 }
259 mono_sample /= self.connected_channels.get() as f32;
260 self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
261
262 Some(mono_sample)
263 }
264}
265
266/// constant source, only works on a single span
267pub struct TakeSamples<S> {
268 inner: S,
269 left_to_take: usize,
270}
271
272impl<S: Source> Iterator for TakeSamples<S> {
273 type Item = Sample;
274
275 fn next(&mut self) -> Option<Self::Item> {
276 if self.left_to_take == 0 {
277 None
278 } else {
279 self.left_to_take -= 1;
280 self.inner.next()
281 }
282 }
283
284 fn size_hint(&self) -> (usize, Option<usize>) {
285 (0, Some(self.left_to_take))
286 }
287}
288
289impl<S: Source> Source for TakeSamples<S> {
290 fn current_span_len(&self) -> Option<usize> {
291 None // does not support spans
292 }
293
294 fn channels(&self) -> ChannelCount {
295 self.inner.channels()
296 }
297
298 fn sample_rate(&self) -> SampleRate {
299 self.inner.sample_rate()
300 }
301
302 fn total_duration(&self) -> Option<Duration> {
303 Some(Duration::from_secs_f64(
304 self.left_to_take as f64
305 / self.sample_rate().get() as f64
306 / self.channels().get() as f64,
307 ))
308 }
309}
310
311/// constant source, only works on a single span
312#[derive(Debug)]
313struct ReplayQueue {
314 inner: ArrayQueue<Vec<Sample>>,
315 normal_chunk_len: usize,
316 /// The last chunk in the queue may be smaller than
317 /// the normal chunk size. This is always equal to the
318 /// size of the last element in the queue.
319 /// (so normally chunk_size)
320 last_chunk: Mutex<Vec<Sample>>,
321}
322
323impl ReplayQueue {
324 fn new(queue_len: usize, chunk_size: usize) -> Self {
325 Self {
326 inner: ArrayQueue::new(queue_len),
327 normal_chunk_len: chunk_size,
328 last_chunk: Mutex::new(Vec::new()),
329 }
330 }
331 /// Returns the length in samples
332 fn len(&self) -> usize {
333 self.inner.len().saturating_sub(1) * self.normal_chunk_len
334 + self
335 .last_chunk
336 .lock()
337 .expect("Self::push_last can not poison this lock")
338 .len()
339 }
340
341 fn pop(&self) -> Option<Vec<Sample>> {
342 self.inner.pop() // removes element that was inserted first
343 }
344
345 fn push_last(&self, mut samples: Vec<Sample>) {
346 let mut last_chunk = self
347 .last_chunk
348 .lock()
349 .expect("Self::len can not poison this lock");
350 std::mem::swap(&mut *last_chunk, &mut samples);
351 }
352
353 fn push_normal(&self, samples: Vec<Sample>) {
354 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
355 }
356}
357
358/// constant source, only works on a single span
359pub struct ProcessBuffer<const N: usize, S, F>
360where
361 S: Source + Sized,
362 F: FnMut(&mut [Sample; N]),
363{
364 inner: S,
365 callback: F,
366 /// Buffer used for both input and output.
367 buffer: [Sample; N],
368 /// Next already processed sample is at this index
369 /// in buffer.
370 ///
371 /// If this is equal to the length of the buffer we have no more samples and
372 /// we must get new ones and process them
373 next: usize,
374}
375
376impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
377where
378 S: Source + Sized,
379 F: FnMut(&mut [Sample; N]),
380{
381 type Item = Sample;
382
383 fn next(&mut self) -> Option<Self::Item> {
384 self.next += 1;
385 if self.next < self.buffer.len() {
386 let sample = self.buffer[self.next];
387 return Some(sample);
388 }
389
390 for sample in &mut self.buffer {
391 *sample = self.inner.next()?
392 }
393 (self.callback)(&mut self.buffer);
394
395 self.next = 0;
396 Some(self.buffer[0])
397 }
398
399 fn size_hint(&self) -> (usize, Option<usize>) {
400 self.inner.size_hint()
401 }
402}
403
404impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
405where
406 S: Source + Sized,
407 F: FnMut(&mut [Sample; N]),
408{
409 fn current_span_len(&self) -> Option<usize> {
410 None
411 }
412
413 fn channels(&self) -> rodio::ChannelCount {
414 self.inner.channels()
415 }
416
417 fn sample_rate(&self) -> rodio::SampleRate {
418 self.inner.sample_rate()
419 }
420
421 fn total_duration(&self) -> Option<std::time::Duration> {
422 self.inner.total_duration()
423 }
424}
425
426/// constant source, only works on a single span
427pub struct InspectBuffer<const N: usize, S, F>
428where
429 S: Source + Sized,
430 F: FnMut(&[Sample; N]),
431{
432 inner: S,
433 callback: F,
434 /// Stores already emitted samples, once its full we call the callback.
435 buffer: [Sample; N],
436 /// Next free element in buffer. If this is equal to the buffer length
437 /// we have no more free lements.
438 free: usize,
439}
440
441impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
442where
443 S: Source + Sized,
444 F: FnMut(&[Sample; N]),
445{
446 type Item = Sample;
447
448 fn next(&mut self) -> Option<Self::Item> {
449 let Some(sample) = self.inner.next() else {
450 return None;
451 };
452
453 self.buffer[self.free] = sample;
454 self.free += 1;
455
456 if self.free == self.buffer.len() {
457 (self.callback)(&self.buffer);
458 self.free = 0
459 }
460
461 Some(sample)
462 }
463
464 fn size_hint(&self) -> (usize, Option<usize>) {
465 self.inner.size_hint()
466 }
467}
468
469impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
470where
471 S: Source + Sized,
472 F: FnMut(&[Sample; N]),
473{
474 fn current_span_len(&self) -> Option<usize> {
475 None
476 }
477
478 fn channels(&self) -> rodio::ChannelCount {
479 self.inner.channels()
480 }
481
482 fn sample_rate(&self) -> rodio::SampleRate {
483 self.inner.sample_rate()
484 }
485
486 fn total_duration(&self) -> Option<std::time::Duration> {
487 self.inner.total_duration()
488 }
489}
490
491/// constant source, only works on a single span
492#[derive(Debug)]
493pub struct Replayable<S: Source> {
494 inner: S,
495 buffer: Vec<Sample>,
496 chunk_size: usize,
497 tx: Arc<ReplayQueue>,
498 is_active: Arc<AtomicBool>,
499}
500
501impl<S: Source> Iterator for Replayable<S> {
502 type Item = Sample;
503
504 fn next(&mut self) -> Option<Self::Item> {
505 if let Some(sample) = self.inner.next() {
506 self.buffer.push(sample);
507 // If the buffer is full send it
508 if self.buffer.len() == self.chunk_size {
509 self.tx.push_normal(std::mem::take(&mut self.buffer));
510 }
511 Some(sample)
512 } else {
513 let last_chunk = std::mem::take(&mut self.buffer);
514 self.tx.push_last(last_chunk);
515 self.is_active.store(false, Ordering::Relaxed);
516 None
517 }
518 }
519
520 fn size_hint(&self) -> (usize, Option<usize>) {
521 self.inner.size_hint()
522 }
523}
524
525impl<S: Source> Source for Replayable<S> {
526 fn current_span_len(&self) -> Option<usize> {
527 self.inner.current_span_len()
528 }
529
530 fn channels(&self) -> ChannelCount {
531 self.inner.channels()
532 }
533
534 fn sample_rate(&self) -> SampleRate {
535 self.inner.sample_rate()
536 }
537
538 fn total_duration(&self) -> Option<Duration> {
539 self.inner.total_duration()
540 }
541}
542
543/// constant source, only works on a single span
544#[derive(Debug)]
545pub struct Replay {
546 rx: Arc<ReplayQueue>,
547 buffer: std::vec::IntoIter<Sample>,
548 sleep_duration: Duration,
549 sample_rate: SampleRate,
550 channel_count: ChannelCount,
551 source_is_active: Arc<AtomicBool>,
552}
553
554impl Replay {
555 pub fn source_is_active(&self) -> bool {
556 // - source could return None and not drop
557 // - source could be dropped before returning None
558 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
559 }
560
561 /// Duration of what is in the buffer and can be returned without blocking.
562 pub fn duration_ready(&self) -> Duration {
563 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
564
565 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
566 Duration::from_secs_f64(seconds_queued)
567 }
568
569 /// Number of samples in the buffer and can be returned without blocking.
570 pub fn samples_ready(&self) -> usize {
571 self.rx.len() + self.buffer.len()
572 }
573}
574
575impl Iterator for Replay {
576 type Item = Sample;
577
578 fn next(&mut self) -> Option<Self::Item> {
579 if let Some(sample) = self.buffer.next() {
580 return Some(sample);
581 }
582
583 loop {
584 if let Some(new_buffer) = self.rx.pop() {
585 self.buffer = new_buffer.into_iter();
586 return self.buffer.next();
587 }
588
589 if !self.source_is_active() {
590 return None;
591 }
592
593 // The queue does not support blocking on a next item. We want this queue as it
594 // is quite fast and provides a fixed size. We know how many samples are in a
595 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
596 std::thread::sleep(self.sleep_duration);
597 }
598 }
599
600 fn size_hint(&self) -> (usize, Option<usize>) {
601 ((self.rx.len() + self.buffer.len()), None)
602 }
603}
604
605impl Source for Replay {
606 fn current_span_len(&self) -> Option<usize> {
607 None // source is not compatible with spans
608 }
609
610 fn channels(&self) -> ChannelCount {
611 self.channel_count
612 }
613
614 fn sample_rate(&self) -> SampleRate {
615 self.sample_rate
616 }
617
618 fn total_duration(&self) -> Option<Duration> {
619 None
620 }
621}
622
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    // Distinct, easily recognizable sample values used by most tests.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    // Source yielding `SAMPLES` with a channel count of 1 and a sample rate
    // of 1, so one sample corresponds to one second of audio.
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        // A buffer exactly as long as the source sees every sample at once.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Modifications made by the callback are what the adapter yields.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        // With 5 samples and blocks of 3, the trailing 2 samples are dropped.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        // A buffer exactly as long as the source sees every sample at once.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Unlike process_buffer, every sample is still yielded even though
        // the callback only sees the first complete block of 3.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        // The replay first yields the recorded history and then continues
        // with samples the source produces afterwards.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        // With a 2 second history at this sample rate, only the last two
        // samples remain available to the replay.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        // 40_000 samples at 16kHz is 2.5s of audio. The replay should hold
        // (about) the requested 2s; a 100ms margin covers chunk granularity.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        // Nothing is ready before the source runs. After half a second of
        // samples at least that much (minus one 100ms chunk) is reported.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}