1use std::{
2 f32,
3 sync::{
4 Arc, Mutex,
5 atomic::{AtomicBool, Ordering},
6 },
7 time::Duration,
8};
9
10use crossbeam::queue::ArrayQueue;
11use denoise::{Denoiser, DenoiserError};
12use rodio::{ChannelCount, Sample, SampleRate, Source, source::UniformSourceIterator};
13
/// Error returned by [`RodioExt::replayable`] when the requested replay
/// history is shorter than 100ms.
#[derive(Debug, thiserror::Error)]
#[error("Replay duration is too short must be >= 100ms")]
pub struct ReplayDurationTooShort;
17
/// Extension adapters for any [`rodio::Source`].
pub trait RodioExt: Source + Sized {
    /// Groups samples into buffers of `N` and lets `callback` modify each full
    /// buffer in place, then yields the processed samples.
    ///
    /// Samples at the end of the source that do not fill a whole buffer are
    /// dropped.
    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
    where
        F: FnMut(&mut [Sample; N]);
    /// Yields samples unchanged while handing every full buffer of `N` samples
    /// to `callback`.
    ///
    /// A trailing partial buffer is still yielded but never passed to
    /// `callback`.
    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
    where
        F: FnMut(&[Sample; N]);
    /// Splits the source into a pass-through [`Replayable`] and a [`Replay`]
    /// handle that reads back recently played audio.
    ///
    /// # Errors
    /// Returns [`ReplayDurationTooShort`] if `duration` is below 100ms.
    fn replayable(
        self,
        duration: Duration,
    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
    /// Yields at most `n` samples from this source.
    fn take_samples(self, n: usize) -> TakeSamples<Self>;
    /// Wraps the source in a [`Denoiser`].
    ///
    /// # Errors
    /// Returns the [`DenoiserError`] produced by `Denoiser::try_new`.
    fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
    /// Resamples/remaps the source to `channel_count` and `sample_rate` using
    /// rodio's [`UniformSourceIterator`].
    fn constant_params(
        self,
        channel_count: ChannelCount,
        sample_rate: SampleRate,
    ) -> UniformSourceIterator<Self>;
    /// Downmixes to mono. For stereo input, the first channel alone is used
    /// while the second channel looks (near) silent; otherwise both channels
    /// are averaged.
    fn suspicious_stereo_to_mono(self) -> ToMono<Self>;
}
38
39impl<S: Source> RodioExt for S {
40 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
41 where
42 F: FnMut(&mut [Sample; N]),
43 {
44 ProcessBuffer {
45 inner: self,
46 callback,
47 buffer: [0.0; N],
48 next: N,
49 }
50 }
51 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
52 where
53 F: FnMut(&[Sample; N]),
54 {
55 InspectBuffer {
56 inner: self,
57 callback,
58 buffer: [0.0; N],
59 free: 0,
60 }
61 }
62 /// Maintains a live replay with a history of at least `duration` seconds.
63 ///
64 /// Note:
65 /// History can be 100ms longer if the source drops before or while the
66 /// replay is being read
67 ///
68 /// # Errors
69 /// If duration is smaller then 100ms
70 fn replayable(
71 self,
72 duration: Duration,
73 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
74 if duration < Duration::from_millis(100) {
75 return Err(ReplayDurationTooShort);
76 }
77
78 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
79 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
80 let samples_to_queue =
81 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
82
83 let chunk_size =
84 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
85 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
86
87 let is_active = Arc::new(AtomicBool::new(true));
88 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
89 Ok((
90 Replay {
91 rx: Arc::clone(&queue),
92 buffer: Vec::new().into_iter(),
93 sleep_duration: duration / 2,
94 sample_rate: self.sample_rate(),
95 channel_count: self.channels(),
96 source_is_active: is_active.clone(),
97 },
98 Replayable {
99 tx: queue,
100 inner: self,
101 buffer: Vec::with_capacity(chunk_size),
102 chunk_size,
103 is_active,
104 },
105 ))
106 }
107 fn take_samples(self, n: usize) -> TakeSamples<S> {
108 TakeSamples {
109 inner: self,
110 left_to_take: n,
111 }
112 }
113 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
114 let res = Denoiser::try_new(self);
115 log::info!("result of new: {res:?}");
116 res
117 }
118 fn constant_params(
119 self,
120 channel_count: ChannelCount,
121 sample_rate: SampleRate,
122 ) -> UniformSourceIterator<Self> {
123 UniformSourceIterator::new(self, channel_count, sample_rate)
124 }
125 fn suspicious_stereo_to_mono(self) -> ToMono<Self> {
126 ToMono {
127 input_channel_count: self.channels(),
128 inner: self,
129 mean: f32::EPSILON * 3.0,
130 }
131 }
132}
133
134/// constant source, only works on a single span
135pub struct ToMono<S> {
136 inner: S,
137 input_channel_count: ChannelCount,
138 /// running mean of second channel 'volume'
139 mean: f32,
140}
141impl<S> ToMono<S> {
142 fn real_stereo(&self) -> bool {
143 dbg!(self.mean);
144 dbg!(self.mean >= f32::EPSILON * 3.0)
145 }
146
147 fn update_mean(&mut self, second_channel: Sample) {
148 const HISTORY: f32 = 500.0;
149 self.mean *= (HISTORY - 1.0) / HISTORY;
150 self.mean += second_channel.abs() / HISTORY;
151 }
152}
153
154impl<S: Source> Source for ToMono<S> {
155 fn current_span_len(&self) -> Option<usize> {
156 None
157 }
158
159 fn channels(&self) -> ChannelCount {
160 rodio::nz!(1)
161 }
162
163 fn sample_rate(&self) -> SampleRate {
164 self.inner.sample_rate()
165 }
166
167 fn total_duration(&self) -> Option<Duration> {
168 self.inner.total_duration()
169 }
170}
171
172impl<S: Source> Iterator for ToMono<S> {
173 type Item = Sample;
174
175 fn next(&mut self) -> Option<Self::Item> {
176 match self.input_channel_count.get() {
177 1 => self.next(),
178 2 => {
179 let first_channel = self.next()?;
180 let second_channel = self.next()?;
181 self.update_mean(second_channel);
182
183 if self.real_stereo() {
184 Some((first_channel + second_channel) / 2.0)
185 } else {
186 Some(first_channel)
187 }
188 }
189 _ => todo!("unsupported channel count"),
190 }
191 }
192}
193
/// Adapter returned by [`RodioExt::take_samples`]: yields at most
/// `left_to_take` further samples from `inner`.
///
/// constant source, only works on a single span
pub struct TakeSamples<S> {
    inner: S,
    /// Number of samples this adapter may still yield.
    left_to_take: usize,
}
199
200impl<S: Source> Iterator for TakeSamples<S> {
201 type Item = Sample;
202
203 fn next(&mut self) -> Option<Self::Item> {
204 if self.left_to_take == 0 {
205 None
206 } else {
207 self.left_to_take -= 1;
208 self.inner.next()
209 }
210 }
211
212 fn size_hint(&self) -> (usize, Option<usize>) {
213 (0, Some(self.left_to_take))
214 }
215}
216
217impl<S: Source> Source for TakeSamples<S> {
218 fn current_span_len(&self) -> Option<usize> {
219 None // does not support spans
220 }
221
222 fn channels(&self) -> ChannelCount {
223 self.inner.channels()
224 }
225
226 fn sample_rate(&self) -> SampleRate {
227 self.inner.sample_rate()
228 }
229
230 fn total_duration(&self) -> Option<Duration> {
231 Some(Duration::from_secs_f64(
232 self.left_to_take as f64
233 / self.sample_rate().get() as f64
234 / self.channels().get() as f64,
235 ))
236 }
237}
238
239/// constant source, only works on a single span
240#[derive(Debug)]
241struct ReplayQueue {
242 inner: ArrayQueue<Vec<Sample>>,
243 normal_chunk_len: usize,
244 /// The last chunk in the queue may be smaller then
245 /// the normal chunk size. This is always equal to the
246 /// size of the last element in the queue.
247 /// (so normally chunk_size)
248 last_chunk: Mutex<Vec<Sample>>,
249}
250
251impl ReplayQueue {
252 fn new(queue_len: usize, chunk_size: usize) -> Self {
253 Self {
254 inner: ArrayQueue::new(queue_len),
255 normal_chunk_len: chunk_size,
256 last_chunk: Mutex::new(Vec::new()),
257 }
258 }
259 /// Returns the length in samples
260 fn len(&self) -> usize {
261 self.inner.len().saturating_sub(1) * self.normal_chunk_len
262 + self
263 .last_chunk
264 .lock()
265 .expect("Self::push_last can not poison this lock")
266 .len()
267 }
268
269 fn pop(&self) -> Option<Vec<Sample>> {
270 self.inner.pop() // removes element that was inserted first
271 }
272
273 fn push_last(&self, mut samples: Vec<Sample>) {
274 let mut last_chunk = self
275 .last_chunk
276 .lock()
277 .expect("Self::len can not poison this lock");
278 std::mem::swap(&mut *last_chunk, &mut samples);
279 }
280
281 fn push_normal(&self, samples: Vec<Sample>) {
282 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
283 }
284}
285
/// Adapter returned by [`RodioExt::process_buffer`]: pulls `N` samples at a
/// time from `inner`, lets `callback` modify them in place, then yields them.
///
/// constant source, only works on a single span
pub struct ProcessBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&mut [Sample; N]),
{
    inner: S,
    /// Called once per full buffer, before any of its samples are yielded.
    callback: F,
    /// Buffer used for both input and output.
    buffer: [Sample; N],
    /// Read position within `buffer`. It starts at `N`. Once it reaches the
    /// end of the buffer, the iterator pulls a fresh buffer from `inner` and
    /// processes it before yielding more samples.
    next: usize,
}
303
304impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
305where
306 S: Source + Sized,
307 F: FnMut(&mut [Sample; N]),
308{
309 type Item = Sample;
310
311 fn next(&mut self) -> Option<Self::Item> {
312 self.next += 1;
313 if self.next < self.buffer.len() {
314 let sample = self.buffer[self.next];
315 return Some(sample);
316 }
317
318 for sample in &mut self.buffer {
319 *sample = self.inner.next()?
320 }
321 (self.callback)(&mut self.buffer);
322
323 self.next = 0;
324 Some(self.buffer[0])
325 }
326
327 fn size_hint(&self) -> (usize, Option<usize>) {
328 self.inner.size_hint()
329 }
330}
331
332impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
333where
334 S: Source + Sized,
335 F: FnMut(&mut [Sample; N]),
336{
337 fn current_span_len(&self) -> Option<usize> {
338 None
339 }
340
341 fn channels(&self) -> rodio::ChannelCount {
342 self.inner.channels()
343 }
344
345 fn sample_rate(&self) -> rodio::SampleRate {
346 self.inner.sample_rate()
347 }
348
349 fn total_duration(&self) -> Option<std::time::Duration> {
350 self.inner.total_duration()
351 }
352}
353
/// Adapter returned by [`RodioExt::inspect_buffer`]: yields samples unchanged
/// and shows every full buffer of `N` samples to `callback`.
///
/// constant source, only works on a single span
pub struct InspectBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&[Sample; N]),
{
    inner: S,
    /// Called with each completely filled buffer.
    callback: F,
    /// Stores already emitted samples; once it is full we call the callback.
    buffer: [Sample; N],
    /// Next free element in buffer. If this is equal to the buffer length
    /// we have no more free elements.
    free: usize,
}
368
369impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
370where
371 S: Source + Sized,
372 F: FnMut(&[Sample; N]),
373{
374 type Item = Sample;
375
376 fn next(&mut self) -> Option<Self::Item> {
377 let Some(sample) = self.inner.next() else {
378 return None;
379 };
380
381 self.buffer[self.free] = sample;
382 self.free += 1;
383
384 if self.free == self.buffer.len() {
385 (self.callback)(&self.buffer);
386 self.free = 0
387 }
388
389 Some(sample)
390 }
391
392 fn size_hint(&self) -> (usize, Option<usize>) {
393 self.inner.size_hint()
394 }
395}
396
397impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
398where
399 S: Source + Sized,
400 F: FnMut(&[Sample; N]),
401{
402 fn current_span_len(&self) -> Option<usize> {
403 None
404 }
405
406 fn channels(&self) -> rodio::ChannelCount {
407 self.inner.channels()
408 }
409
410 fn sample_rate(&self) -> rodio::SampleRate {
411 self.inner.sample_rate()
412 }
413
414 fn total_duration(&self) -> Option<std::time::Duration> {
415 self.inner.total_duration()
416 }
417}
418
/// Write side of a replay created by [`RodioExt::replayable`]: passes samples
/// through unchanged while copying them, in chunks of `chunk_size`, into the
/// queue shared with a [`Replay`].
///
/// constant source, only works on a single span
#[derive(Debug)]
pub struct Replayable<S: Source> {
    inner: S,
    /// Samples collected since the last full chunk was queued.
    buffer: Vec<Sample>,
    /// Number of samples per queued chunk.
    chunk_size: usize,
    /// Queue shared with the paired `Replay`.
    tx: Arc<ReplayQueue>,
    /// Shared flag, set to `false` once `inner` has returned `None`.
    is_active: Arc<AtomicBool>,
}
428
429impl<S: Source> Iterator for Replayable<S> {
430 type Item = Sample;
431
432 fn next(&mut self) -> Option<Self::Item> {
433 if let Some(sample) = self.inner.next() {
434 self.buffer.push(sample);
435 // If the buffer is full send it
436 if self.buffer.len() == self.chunk_size {
437 self.tx.push_normal(std::mem::take(&mut self.buffer));
438 }
439 Some(sample)
440 } else {
441 let last_chunk = std::mem::take(&mut self.buffer);
442 self.tx.push_last(last_chunk);
443 self.is_active.store(false, Ordering::Relaxed);
444 None
445 }
446 }
447
448 fn size_hint(&self) -> (usize, Option<usize>) {
449 self.inner.size_hint()
450 }
451}
452
453impl<S: Source> Source for Replayable<S> {
454 fn current_span_len(&self) -> Option<usize> {
455 self.inner.current_span_len()
456 }
457
458 fn channels(&self) -> ChannelCount {
459 self.inner.channels()
460 }
461
462 fn sample_rate(&self) -> SampleRate {
463 self.inner.sample_rate()
464 }
465
466 fn total_duration(&self) -> Option<Duration> {
467 self.inner.total_duration()
468 }
469}
470
/// Read side of a replay created by [`RodioExt::replayable`]: yields the
/// audio queued by the paired [`Replayable`], sleeping between polls while
/// the queue is empty and the source is still active.
///
/// constant source, only works on a single span
#[derive(Debug)]
pub struct Replay {
    /// Queue shared with the `Replayable` that fills it.
    rx: Arc<ReplayQueue>,
    /// Remaining samples of the chunk currently being yielded.
    buffer: std::vec::IntoIter<Sample>,
    /// How long to sleep before polling the queue again when it is empty.
    sleep_duration: Duration,
    sample_rate: SampleRate,
    channel_count: ChannelCount,
    /// Flag shared with the source; the source clears it when it returns `None`.
    source_is_active: Arc<AtomicBool>,
}
481
482impl Replay {
483 pub fn source_is_active(&self) -> bool {
484 // - source could return None and not drop
485 // - source could be dropped before returning None
486 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
487 }
488
489 /// Duration of what is in the buffer and can be returned without blocking.
490 pub fn duration_ready(&self) -> Duration {
491 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
492
493 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
494 Duration::from_secs_f64(seconds_queued)
495 }
496
497 /// Number of samples in the buffer and can be returned without blocking.
498 pub fn samples_ready(&self) -> usize {
499 self.rx.len() + self.buffer.len()
500 }
501}
502
503impl Iterator for Replay {
504 type Item = Sample;
505
506 fn next(&mut self) -> Option<Self::Item> {
507 if let Some(sample) = self.buffer.next() {
508 return Some(sample);
509 }
510
511 loop {
512 if let Some(new_buffer) = self.rx.pop() {
513 self.buffer = new_buffer.into_iter();
514 return self.buffer.next();
515 }
516
517 if !self.source_is_active() {
518 return None;
519 }
520
521 // The queue does not support blocking on a next item. We want this queue as it
522 // is quite fast and provides a fixed size. We know how many samples are in a
523 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
524 std::thread::sleep(self.sleep_duration);
525 }
526 }
527
528 fn size_hint(&self) -> (usize, Option<usize>) {
529 ((self.rx.len() + self.buffer.len()), None)
530 }
531}
532
533impl Source for Replay {
534 fn current_span_len(&self) -> Option<usize> {
535 None // source is not compatible with spans
536 }
537
538 fn channels(&self) -> ChannelCount {
539 self.channel_count
540 }
541
542 fn sample_rate(&self) -> SampleRate {
543 self.sample_rate
544 }
545
546 fn total_duration(&self) -> Option<Duration> {
547 None
548 }
549}
550
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Distinct values so tests can check ordering and truncation.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Mono source at a 1 Hz sample rate (one sample per second), which keeps
    /// the replay-duration arithmetic in the tests below easy to follow.
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        /// A buffer exactly as long as the input sees every sample at once.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        /// Changes made by the callback show up in the yielded samples.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        /// 5 samples with a buffer of 3: only the one whole buffer is yielded.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        /// A buffer exactly as long as the input sees every sample at once.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        /// Unlike `process_buffer`, samples in a trailing partial buffer are
        /// still yielded; only the first full buffer reaches the callback.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        /// The replay first yields the history, then continues with samples
        /// the source produces afterwards.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer then 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        /// With 2 s of history at 1 Hz, only the last two samples are kept.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        /// 40_000 samples at 16 kHz mono is 2.5 s of audio. The replay should
        /// retain about 2 s of it, allowing one 100ms chunk of slack.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        /// Nothing is ready before the source runs. After half a second of
        /// audio, roughly that much is ready, minus one chunk of slack.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}