1use std::{
2 num::NonZero,
3 sync::{
4 Arc, Mutex,
5 atomic::{AtomicBool, Ordering},
6 },
7 time::Duration,
8};
9
10use crossbeam::queue::ArrayQueue;
11use denoise::{Denoiser, DenoiserError};
12use log::warn;
13use rodio::{
14 ChannelCount, Sample, SampleRate, Source, conversions::SampleRateConverter, nz,
15 source::UniformSourceIterator,
16};
17
/// Maximum number of input channels [`ToMono`] keeps per-channel state for
/// (the length of its running-mean array).
const MAX_CHANNELS: usize = 8;
19
20#[derive(Debug, thiserror::Error)]
21#[error("Replay duration is too short must be >= 100ms")]
22pub struct ReplayDurationTooShort;
23
24// These all require constant sources (so the span is infinitely long)
25// this is not guaranteed by rodio however we know it to be true in all our
26// applications. Rodio desperately needs a constant source concept.
27pub trait RodioExt: Source + Sized {
28 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
29 where
30 F: FnMut(&mut [Sample; N]);
31 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
32 where
33 F: FnMut(&[Sample; N]);
34 fn replayable(
35 self,
36 duration: Duration,
37 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
38 fn take_samples(self, n: usize) -> TakeSamples<Self>;
39 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
40 fn constant_params(
41 self,
42 channel_count: ChannelCount,
43 sample_rate: SampleRate,
44 ) -> UniformSourceIterator<Self>;
45 fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self>;
46 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
47}
48
49impl<S: Source> RodioExt for S {
50 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
51 where
52 F: FnMut(&mut [Sample; N]),
53 {
54 ProcessBuffer {
55 inner: self,
56 callback,
57 buffer: [0.0; N],
58 next: N,
59 }
60 }
61 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
62 where
63 F: FnMut(&[Sample; N]),
64 {
65 InspectBuffer {
66 inner: self,
67 callback,
68 buffer: [0.0; N],
69 free: 0,
70 }
71 }
72 /// Maintains a live replay with a history of at least `duration` seconds.
73 ///
74 /// Note:
75 /// History can be 100ms longer if the source drops before or while the
76 /// replay is being read
77 ///
78 /// # Errors
79 /// If duration is smaller than 100ms
80 fn replayable(
81 self,
82 duration: Duration,
83 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
84 if duration < Duration::from_millis(100) {
85 return Err(ReplayDurationTooShort);
86 }
87
88 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
89 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
90 let samples_to_queue =
91 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
92
93 let chunk_size =
94 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
95 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
96
97 let is_active = Arc::new(AtomicBool::new(true));
98 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
99 Ok((
100 Replay {
101 rx: Arc::clone(&queue),
102 buffer: Vec::new().into_iter(),
103 sleep_duration: duration / 2,
104 sample_rate: self.sample_rate(),
105 channel_count: self.channels(),
106 source_is_active: is_active.clone(),
107 },
108 Replayable {
109 tx: queue,
110 inner: self,
111 buffer: Vec::with_capacity(chunk_size),
112 chunk_size,
113 is_active,
114 },
115 ))
116 }
117 fn take_samples(self, n: usize) -> TakeSamples<S> {
118 TakeSamples {
119 inner: self,
120 left_to_take: n,
121 }
122 }
123 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
124 let res = Denoiser::try_new(self);
125 res
126 }
127 fn constant_params(
128 self,
129 channel_count: ChannelCount,
130 sample_rate: SampleRate,
131 ) -> UniformSourceIterator<Self> {
132 UniformSourceIterator::new(self, channel_count, sample_rate)
133 }
134 fn constant_samplerate(self, sample_rate: SampleRate) -> ConstantSampleRate<Self> {
135 ConstantSampleRate::new(self, sample_rate)
136 }
137 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
138 ToMono::new(self)
139 }
140}
141
142pub struct ConstantSampleRate<S: Source> {
143 inner: SampleRateConverter<S>,
144 channels: ChannelCount,
145 sample_rate: SampleRate,
146}
147
148impl<S: Source> ConstantSampleRate<S> {
149 fn new(source: S, target_rate: SampleRate) -> Self {
150 let input_sample_rate = source.sample_rate();
151 let channels = source.channels();
152 let inner = SampleRateConverter::new(source, input_sample_rate, target_rate, channels);
153 Self {
154 inner,
155 channels,
156 sample_rate: target_rate,
157 }
158 }
159}
160
161impl<S: Source> Iterator for ConstantSampleRate<S> {
162 type Item = rodio::Sample;
163
164 fn next(&mut self) -> Option<Self::Item> {
165 self.inner.next()
166 }
167
168 fn size_hint(&self) -> (usize, Option<usize>) {
169 self.inner.size_hint()
170 }
171}
172
173impl<S: Source> Source for ConstantSampleRate<S> {
174 fn current_span_len(&self) -> Option<usize> {
175 None
176 }
177
178 fn channels(&self) -> ChannelCount {
179 self.channels
180 }
181
182 fn sample_rate(&self) -> SampleRate {
183 self.sample_rate
184 }
185
186 fn total_duration(&self) -> Option<Duration> {
187 None // not supported (not used by us)
188 }
189}
190
191const TYPICAL_NOISE_FLOOR: Sample = 1e-3;
192
193/// constant source, only works on a single span
194pub struct ToMono<S> {
195 inner: S,
196 input_channel_count: ChannelCount,
197 connected_channels: ChannelCount,
198 /// running mean of second channel 'volume'
199 means: [f32; MAX_CHANNELS],
200}
201impl<S: Source> ToMono<S> {
202 fn new(input: S) -> Self {
203 let channels = input
204 .channels()
205 .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
206 if channels < input.channels() {
207 warn!("Ignoring input channels {}..", channels.get());
208 }
209
210 Self {
211 connected_channels: channels,
212 input_channel_count: channels,
213 inner: input,
214 means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
215 }
216 }
217}
218
219impl<S: Source> Source for ToMono<S> {
220 fn current_span_len(&self) -> Option<usize> {
221 None
222 }
223
224 fn channels(&self) -> ChannelCount {
225 rodio::nz!(1)
226 }
227
228 fn sample_rate(&self) -> SampleRate {
229 self.inner.sample_rate()
230 }
231
232 fn total_duration(&self) -> Option<Duration> {
233 self.inner.total_duration()
234 }
235}
236
237fn update_mean(mean: &mut f32, sample: Sample) {
238 const HISTORY: f32 = 500.0;
239 *mean *= (HISTORY - 1.0) / HISTORY;
240 *mean += sample.abs() / HISTORY;
241}
242
243impl<S: Source> Iterator for ToMono<S> {
244 type Item = Sample;
245
246 fn next(&mut self) -> Option<Self::Item> {
247 let mut mono_sample = 0f32;
248 let mut active_channels = 0;
249 for channel in 0..self.input_channel_count.get() as usize {
250 let sample = self.inner.next()?;
251 mono_sample += sample;
252
253 update_mean(&mut self.means[channel], sample);
254 if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
255 active_channels += 1;
256 }
257 }
258 mono_sample /= self.connected_channels.get() as f32;
259 self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
260
261 Some(mono_sample)
262 }
263}
264
265/// constant source, only works on a single span
266pub struct TakeSamples<S> {
267 inner: S,
268 left_to_take: usize,
269}
270
271impl<S: Source> Iterator for TakeSamples<S> {
272 type Item = Sample;
273
274 fn next(&mut self) -> Option<Self::Item> {
275 if self.left_to_take == 0 {
276 None
277 } else {
278 self.left_to_take -= 1;
279 self.inner.next()
280 }
281 }
282
283 fn size_hint(&self) -> (usize, Option<usize>) {
284 (0, Some(self.left_to_take))
285 }
286}
287
288impl<S: Source> Source for TakeSamples<S> {
289 fn current_span_len(&self) -> Option<usize> {
290 None // does not support spans
291 }
292
293 fn channels(&self) -> ChannelCount {
294 self.inner.channels()
295 }
296
297 fn sample_rate(&self) -> SampleRate {
298 self.inner.sample_rate()
299 }
300
301 fn total_duration(&self) -> Option<Duration> {
302 Some(Duration::from_secs_f64(
303 self.left_to_take as f64
304 / self.sample_rate().get() as f64
305 / self.channels().get() as f64,
306 ))
307 }
308}
309
310/// constant source, only works on a single span
311#[derive(Debug)]
312struct ReplayQueue {
313 inner: ArrayQueue<Vec<Sample>>,
314 normal_chunk_len: usize,
315 /// The last chunk in the queue may be smaller than
316 /// the normal chunk size. This is always equal to the
317 /// size of the last element in the queue.
318 /// (so normally chunk_size)
319 last_chunk: Mutex<Vec<Sample>>,
320}
321
322impl ReplayQueue {
323 fn new(queue_len: usize, chunk_size: usize) -> Self {
324 Self {
325 inner: ArrayQueue::new(queue_len),
326 normal_chunk_len: chunk_size,
327 last_chunk: Mutex::new(Vec::new()),
328 }
329 }
330 /// Returns the length in samples
331 fn len(&self) -> usize {
332 self.inner.len().saturating_sub(1) * self.normal_chunk_len
333 + self
334 .last_chunk
335 .lock()
336 .expect("Self::push_last can not poison this lock")
337 .len()
338 }
339
340 fn pop(&self) -> Option<Vec<Sample>> {
341 self.inner.pop() // removes element that was inserted first
342 }
343
344 fn push_last(&self, mut samples: Vec<Sample>) {
345 let mut last_chunk = self
346 .last_chunk
347 .lock()
348 .expect("Self::len can not poison this lock");
349 std::mem::swap(&mut *last_chunk, &mut samples);
350 }
351
352 fn push_normal(&self, samples: Vec<Sample>) {
353 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
354 }
355}
356
357/// constant source, only works on a single span
358pub struct ProcessBuffer<const N: usize, S, F>
359where
360 S: Source + Sized,
361 F: FnMut(&mut [Sample; N]),
362{
363 inner: S,
364 callback: F,
365 /// Buffer used for both input and output.
366 buffer: [Sample; N],
367 /// Next already processed sample is at this index
368 /// in buffer.
369 ///
370 /// If this is equal to the length of the buffer we have no more samples and
371 /// we must get new ones and process them
372 next: usize,
373}
374
375impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
376where
377 S: Source + Sized,
378 F: FnMut(&mut [Sample; N]),
379{
380 type Item = Sample;
381
382 fn next(&mut self) -> Option<Self::Item> {
383 self.next += 1;
384 if self.next < self.buffer.len() {
385 let sample = self.buffer[self.next];
386 return Some(sample);
387 }
388
389 for sample in &mut self.buffer {
390 *sample = self.inner.next()?
391 }
392 (self.callback)(&mut self.buffer);
393
394 self.next = 0;
395 Some(self.buffer[0])
396 }
397
398 fn size_hint(&self) -> (usize, Option<usize>) {
399 self.inner.size_hint()
400 }
401}
402
403impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
404where
405 S: Source + Sized,
406 F: FnMut(&mut [Sample; N]),
407{
408 fn current_span_len(&self) -> Option<usize> {
409 None
410 }
411
412 fn channels(&self) -> rodio::ChannelCount {
413 self.inner.channels()
414 }
415
416 fn sample_rate(&self) -> rodio::SampleRate {
417 self.inner.sample_rate()
418 }
419
420 fn total_duration(&self) -> Option<std::time::Duration> {
421 self.inner.total_duration()
422 }
423}
424
425/// constant source, only works on a single span
426pub struct InspectBuffer<const N: usize, S, F>
427where
428 S: Source + Sized,
429 F: FnMut(&[Sample; N]),
430{
431 inner: S,
432 callback: F,
433 /// Stores already emitted samples, once its full we call the callback.
434 buffer: [Sample; N],
435 /// Next free element in buffer. If this is equal to the buffer length
436 /// we have no more free elements.
437 free: usize,
438}
439
440impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
441where
442 S: Source + Sized,
443 F: FnMut(&[Sample; N]),
444{
445 type Item = Sample;
446
447 fn next(&mut self) -> Option<Self::Item> {
448 let Some(sample) = self.inner.next() else {
449 return None;
450 };
451
452 self.buffer[self.free] = sample;
453 self.free += 1;
454
455 if self.free == self.buffer.len() {
456 (self.callback)(&self.buffer);
457 self.free = 0
458 }
459
460 Some(sample)
461 }
462
463 fn size_hint(&self) -> (usize, Option<usize>) {
464 self.inner.size_hint()
465 }
466}
467
468impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
469where
470 S: Source + Sized,
471 F: FnMut(&[Sample; N]),
472{
473 fn current_span_len(&self) -> Option<usize> {
474 None
475 }
476
477 fn channels(&self) -> rodio::ChannelCount {
478 self.inner.channels()
479 }
480
481 fn sample_rate(&self) -> rodio::SampleRate {
482 self.inner.sample_rate()
483 }
484
485 fn total_duration(&self) -> Option<std::time::Duration> {
486 self.inner.total_duration()
487 }
488}
489
490/// constant source, only works on a single span
491#[derive(Debug)]
492pub struct Replayable<S: Source> {
493 inner: S,
494 buffer: Vec<Sample>,
495 chunk_size: usize,
496 tx: Arc<ReplayQueue>,
497 is_active: Arc<AtomicBool>,
498}
499
500impl<S: Source> Iterator for Replayable<S> {
501 type Item = Sample;
502
503 fn next(&mut self) -> Option<Self::Item> {
504 if let Some(sample) = self.inner.next() {
505 self.buffer.push(sample);
506 // If the buffer is full send it
507 if self.buffer.len() == self.chunk_size {
508 self.tx.push_normal(std::mem::take(&mut self.buffer));
509 }
510 Some(sample)
511 } else {
512 let last_chunk = std::mem::take(&mut self.buffer);
513 self.tx.push_last(last_chunk);
514 self.is_active.store(false, Ordering::Relaxed);
515 None
516 }
517 }
518
519 fn size_hint(&self) -> (usize, Option<usize>) {
520 self.inner.size_hint()
521 }
522}
523
524impl<S: Source> Source for Replayable<S> {
525 fn current_span_len(&self) -> Option<usize> {
526 self.inner.current_span_len()
527 }
528
529 fn channels(&self) -> ChannelCount {
530 self.inner.channels()
531 }
532
533 fn sample_rate(&self) -> SampleRate {
534 self.inner.sample_rate()
535 }
536
537 fn total_duration(&self) -> Option<Duration> {
538 self.inner.total_duration()
539 }
540}
541
542/// constant source, only works on a single span
543#[derive(Debug)]
544pub struct Replay {
545 rx: Arc<ReplayQueue>,
546 buffer: std::vec::IntoIter<Sample>,
547 sleep_duration: Duration,
548 sample_rate: SampleRate,
549 channel_count: ChannelCount,
550 source_is_active: Arc<AtomicBool>,
551}
552
553impl Replay {
554 pub fn source_is_active(&self) -> bool {
555 // - source could return None and not drop
556 // - source could be dropped before returning None
557 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
558 }
559
560 /// Duration of what is in the buffer and can be returned without blocking.
561 pub fn duration_ready(&self) -> Duration {
562 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
563
564 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
565 Duration::from_secs_f64(seconds_queued)
566 }
567
568 /// Number of samples in the buffer and can be returned without blocking.
569 pub fn samples_ready(&self) -> usize {
570 self.rx.len() + self.buffer.len()
571 }
572}
573
574impl Iterator for Replay {
575 type Item = Sample;
576
577 fn next(&mut self) -> Option<Self::Item> {
578 if let Some(sample) = self.buffer.next() {
579 return Some(sample);
580 }
581
582 loop {
583 if let Some(new_buffer) = self.rx.pop() {
584 self.buffer = new_buffer.into_iter();
585 return self.buffer.next();
586 }
587
588 if !self.source_is_active() {
589 return None;
590 }
591
592 // The queue does not support blocking on a next item. We want this queue as it
593 // is quite fast and provides a fixed size. We know how many samples are in a
594 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
595 std::thread::sleep(self.sleep_duration);
596 }
597 }
598
599 fn size_hint(&self) -> (usize, Option<usize>) {
600 ((self.rx.len() + self.buffer.len()), None)
601 }
602}
603
604impl Source for Replay {
605 fn current_span_len(&self) -> Option<usize> {
606 None // source is not compatible with spans
607 }
608
609 fn channels(&self) -> ChannelCount {
610 self.channel_count
611 }
612
613 fn sample_rate(&self) -> SampleRate {
614 self.sample_rate
615 }
616
617 fn total_duration(&self) -> Option<Duration> {
618 None
619 }
620}
621
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Five mono samples at 1 Hz, so every replay chunk holds one sample.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|block| assert_eq!(*block, SAMPLES))
                .for_each(|_| {});
        }

        #[test]
        fn callback_modifies_yielded() {
            let expected: Vec<Sample> = SAMPLES.iter().map(|sample| sample + 1.0).collect();

            let yielded: Vec<Sample> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|block| {
                    block.iter_mut().for_each(|sample| *sample += 1.0)
                })
                .collect();

            assert_eq!(yielded, expected);
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            let yielded = test_source()
                .process_buffer::<3, _>(|block| assert_eq!(block, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|block| assert_eq!(*block, SAMPLES))
                .for_each(|_| {});
        }

        #[test]
        fn source_does_not_truncate() {
            let yielded = test_source()
                .inspect_buffer::<3, _>(|block| assert_eq!(block, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len());
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            // Record three samples, then replay exactly those.
            source.by_ref().take(3).for_each(|_| {});
            let history: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&history, &SAMPLES[0..3]);

            // Run the source to its end; the replay yields the rest and stops.
            source.for_each(|_| {});
            let rest: Vec<Sample> = replay.collect();
            assert_eq!(&rest, &SAMPLES[3..5]);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Pull all five samples without letting the source report its end.
            source.by_ref().take(5).for_each(|_| {});
            let history: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&history, &SAMPLES[3..5]);

            // Now let the source end; nothing is left to replay.
            source.for_each(|_| {});
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // Pull every sample but keep the source alive (it never returns None).
            source.by_ref().take(40_000).for_each(|_| {});

            // Drain only what is available without blocking.
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let two_seconds = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= two_seconds - margin);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).for_each(|_| {}); // half a second
            let margin = 16_000 / 10; // 100ms
            assert!(replay.samples_ready() >= 8000 - margin);
        }
    }
}