1use std::{
2 f32,
3 sync::{
4 Arc, Mutex,
5 atomic::{AtomicBool, Ordering},
6 },
7 time::Duration,
8};
9
10use crossbeam::queue::ArrayQueue;
11use rodio::{ChannelCount, Sample, SampleRate, Source};
12
13#[derive(Debug, thiserror::Error)]
14#[error("Replay duration is too short must be >= 100ms")]
15pub struct ReplayDurationTooShort;
16
17pub trait RodioExt: Source + Sized {
18 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
19 where
20 F: FnMut(&mut [Sample; N]);
21 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
22 where
23 F: FnMut(&[Sample; N]);
24 fn replayable(
25 self,
26 duration: Duration,
27 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
28 fn take_samples(self, n: usize) -> TakeSamples<Self>;
29 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
30 fn constant_params(
31 self,
32 channel_count: ChannelCount,
33 sample_rate: SampleRate,
34 ) -> UniformSourceIterator<Self>;
35 fn suspicious_stereo_to_mono(self) -> ToMono<Self>;
36}
37
38impl<S: Source> RodioExt for S {
39 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
40 where
41 F: FnMut(&mut [Sample; N]),
42 {
43 ProcessBuffer {
44 inner: self,
45 callback,
46 buffer: [0.0; N],
47 next: N,
48 }
49 }
50 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
51 where
52 F: FnMut(&[Sample; N]),
53 {
54 InspectBuffer {
55 inner: self,
56 callback,
57 buffer: [0.0; N],
58 free: 0,
59 }
60 }
61 /// Maintains a live replay with a history of at least `duration` seconds.
62 ///
63 /// Note:
64 /// History can be 100ms longer if the source drops before or while the
65 /// replay is being read
66 ///
67 /// # Errors
68 /// If duration is smaller than 100ms
69 fn replayable(
70 self,
71 duration: Duration,
72 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
73 if duration < Duration::from_millis(100) {
74 return Err(ReplayDurationTooShort);
75 }
76
77 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
78 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
79 let samples_to_queue =
80 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
81
82 let chunk_size =
83 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
84 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
85
86 let is_active = Arc::new(AtomicBool::new(true));
87 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
88 Ok((
89 Replay {
90 rx: Arc::clone(&queue),
91 buffer: Vec::new().into_iter(),
92 sleep_duration: duration / 2,
93 sample_rate: self.sample_rate(),
94 channel_count: self.channels(),
95 source_is_active: is_active.clone(),
96 },
97 Replayable {
98 tx: queue,
99 inner: self,
100 buffer: Vec::with_capacity(chunk_size),
101 chunk_size,
102 is_active,
103 },
104 ))
105 }
106 fn take_samples(self, n: usize) -> TakeSamples<S> {
107 TakeSamples {
108 inner: self,
109 left_to_take: n,
110 }
111 }
112 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
113 let res = Denoiser::try_new(self);
114 log::info!("result of new: {res:?}");
115 res
116 }
117 fn constant_params(
118 self,
119 channel_count: ChannelCount,
120 sample_rate: SampleRate,
121 ) -> UniformSourceIterator<Self> {
122 UniformSourceIterator::new(self, channel_count, sample_rate)
123 }
124 fn suspicious_stereo_to_mono(self) -> ToMono<Self> {
125 ToMono {
126 input_channel_count: self.channels(),
127 inner: self,
128 mean: f32::EPSILON * 3.0,
129 }
130 }
131}
132
133/// constant source, only works on a single span
134pub struct ToMono<S> {
135 inner: S,
136 input_channel_count: ChannelCount,
137 /// running mean of second channel 'volume'
138 mean: f32,
139}
140impl<S> ToMono<S> {
141 fn real_stereo(&self) -> bool {
142 dbg!(self.mean);
143 dbg!(self.mean >= f32::EPSILON * 3.0)
144 }
145
146 fn update_mean(&mut self, second_channel: Sample) {
147 const HISTORY: f32 = 500.0;
148 self.mean *= (HISTORY - 1.0) / HISTORY;
149 self.mean += second_channel.abs() / HISTORY;
150 }
151}
152
153impl<S: Source> Source for ToMono<S> {
154 fn current_span_len(&self) -> Option<usize> {
155 None
156 }
157
158 fn channels(&self) -> ChannelCount {
159 rodio::nz!(1)
160 }
161
162 fn sample_rate(&self) -> SampleRate {
163 self.inner.sample_rate()
164 }
165
166 fn total_duration(&self) -> Option<Duration> {
167 self.inner.total_duration()
168 }
169}
170
171impl<S: Source> Iterator for ToMono<S> {
172 type Item = Sample;
173
174 fn next(&mut self) -> Option<Self::Item> {
175 match self.input_channel_count.get() {
176 1 => self.next(),
177 2 => {
178 let first_channel = self.next()?;
179 let second_channel = self.next()?;
180 self.update_mean(second_channel);
181
182 if self.real_stereo() {
183 Some((first_channel + second_channel) / 2.0)
184 } else {
185 Some(first_channel)
186 }
187 }
188 _ => todo!("unsupported channel count"),
189 }
190 }
191}
192
193/// constant source, only works on a single span
194pub struct TakeSamples<S> {
195 inner: S,
196 left_to_take: usize,
197}
198
199impl<S: Source> Iterator for TakeSamples<S> {
200 type Item = Sample;
201
202 fn next(&mut self) -> Option<Self::Item> {
203 if self.left_to_take == 0 {
204 None
205 } else {
206 self.left_to_take -= 1;
207 self.inner.next()
208 }
209 }
210
211 fn size_hint(&self) -> (usize, Option<usize>) {
212 (0, Some(self.left_to_take))
213 }
214}
215
216impl<S: Source> Source for TakeSamples<S> {
217 fn current_span_len(&self) -> Option<usize> {
218 None // does not support spans
219 }
220
221 fn channels(&self) -> ChannelCount {
222 self.inner.channels()
223 }
224
225 fn sample_rate(&self) -> SampleRate {
226 self.inner.sample_rate()
227 }
228
229 fn total_duration(&self) -> Option<Duration> {
230 Some(Duration::from_secs_f64(
231 self.left_to_take as f64
232 / self.sample_rate().get() as f64
233 / self.channels().get() as f64,
234 ))
235 }
236}
237
238/// constant source, only works on a single span
239#[derive(Debug)]
240struct ReplayQueue {
241 inner: ArrayQueue<Vec<Sample>>,
242 normal_chunk_len: usize,
243 /// The last chunk in the queue may be smaller than
244 /// the normal chunk size. This is always equal to the
245 /// size of the last element in the queue.
246 /// (so normally chunk_size)
247 last_chunk: Mutex<Vec<Sample>>,
248}
249
250impl ReplayQueue {
251 fn new(queue_len: usize, chunk_size: usize) -> Self {
252 Self {
253 inner: ArrayQueue::new(queue_len),
254 normal_chunk_len: chunk_size,
255 last_chunk: Mutex::new(Vec::new()),
256 }
257 }
258 /// Returns the length in samples
259 fn len(&self) -> usize {
260 self.inner.len().saturating_sub(1) * self.normal_chunk_len
261 + self
262 .last_chunk
263 .lock()
264 .expect("Self::push_last can not poison this lock")
265 .len()
266 }
267
268 fn pop(&self) -> Option<Vec<Sample>> {
269 self.inner.pop() // removes element that was inserted first
270 }
271
272 fn push_last(&self, mut samples: Vec<Sample>) {
273 let mut last_chunk = self
274 .last_chunk
275 .lock()
276 .expect("Self::len can not poison this lock");
277 std::mem::swap(&mut *last_chunk, &mut samples);
278 }
279
280 fn push_normal(&self, samples: Vec<Sample>) {
281 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
282 }
283}
284
/// Adapter created by [`RodioExt::process_buffer`]: collects `N` samples at a
/// time, passes them to `callback` for in-place modification and then yields
/// the modified samples. Samples at the end that do not fill a whole buffer
/// are dropped.
///
/// constant source, only works on a single span
pub struct ProcessBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&mut [Sample; N]),
{
    /// Source the samples are pulled from.
    inner: S,
    /// Called on every full buffer before its samples are yielded.
    callback: F,
    /// Buffer used for both input and output.
    buffer: [Sample; N],
    /// Index in `buffer` of the most recently yielded sample; the following
    /// sample is at `next + 1`.
    ///
    /// Once `next + 1` reaches the length of the buffer (it starts at `N`)
    /// every processed sample has been yielded, and a new buffer must be
    /// pulled and processed.
    next: usize,
}
302
303impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
304where
305 S: Source + Sized,
306 F: FnMut(&mut [Sample; N]),
307{
308 type Item = Sample;
309
310 fn next(&mut self) -> Option<Self::Item> {
311 self.next += 1;
312 if self.next < self.buffer.len() {
313 let sample = self.buffer[self.next];
314 return Some(sample);
315 }
316
317 for sample in &mut self.buffer {
318 *sample = self.inner.next()?
319 }
320 (self.callback)(&mut self.buffer);
321
322 self.next = 0;
323 Some(self.buffer[0])
324 }
325
326 fn size_hint(&self) -> (usize, Option<usize>) {
327 self.inner.size_hint()
328 }
329}
330
331impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
332where
333 S: Source + Sized,
334 F: FnMut(&mut [Sample; N]),
335{
336 fn current_span_len(&self) -> Option<usize> {
337 None
338 }
339
340 fn channels(&self) -> rodio::ChannelCount {
341 self.inner.channels()
342 }
343
344 fn sample_rate(&self) -> rodio::SampleRate {
345 self.inner.sample_rate()
346 }
347
348 fn total_duration(&self) -> Option<std::time::Duration> {
349 self.inner.total_duration()
350 }
351}
352
/// Adapter created by [`RodioExt::inspect_buffer`]: yields every sample of
/// the inner source unchanged and calls `callback` each time another `N`
/// samples have been yielded. Samples after the last full buffer never reach
/// the callback.
///
/// constant source, only works on a single span
pub struct InspectBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&[Sample; N]),
{
    /// Source the samples are pulled from.
    inner: S,
    /// Called with every full buffer of yielded samples.
    callback: F,
    /// Stores already emitted samples; once it is full the callback is called.
    buffer: [Sample; N],
    /// Index of the next free element in `buffer`. It is reset to 0 right
    /// after the callback runs, so it only equals the buffer length
    /// transiently.
    free: usize,
}
367
368impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
369where
370 S: Source + Sized,
371 F: FnMut(&[Sample; N]),
372{
373 type Item = Sample;
374
375 fn next(&mut self) -> Option<Self::Item> {
376 let Some(sample) = self.inner.next() else {
377 return None;
378 };
379
380 self.buffer[self.free] = sample;
381 self.free += 1;
382
383 if self.free == self.buffer.len() {
384 (self.callback)(&self.buffer);
385 self.free = 0
386 }
387
388 Some(sample)
389 }
390
391 fn size_hint(&self) -> (usize, Option<usize>) {
392 self.inner.size_hint()
393 }
394}
395
396impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
397where
398 S: Source + Sized,
399 F: FnMut(&[Sample; N]),
400{
401 fn current_span_len(&self) -> Option<usize> {
402 None
403 }
404
405 fn channels(&self) -> rodio::ChannelCount {
406 self.inner.channels()
407 }
408
409 fn sample_rate(&self) -> rodio::SampleRate {
410 self.inner.sample_rate()
411 }
412
413 fn total_duration(&self) -> Option<std::time::Duration> {
414 self.inner.total_duration()
415 }
416}
417
/// Pass-through half of [`RodioExt::replayable`]: yields the inner source's
/// samples unchanged while copying them, in chunks of `chunk_size`, into the
/// queue shared with the matching [`Replay`].
///
/// constant source, only works on a single span
#[derive(Debug)]
pub struct Replayable<S: Source> {
    /// Source whose samples are forwarded and recorded.
    inner: S,
    /// Samples not yet handed to the queue; sent once it holds `chunk_size`
    /// samples.
    buffer: Vec<Sample>,
    /// Number of samples per chunk sent to the queue.
    chunk_size: usize,
    /// Producer side of the queue shared with the `Replay`.
    tx: Arc<ReplayQueue>,
    /// Cleared once `inner` has returned `None`.
    is_active: Arc<AtomicBool>,
}
427
428impl<S: Source> Iterator for Replayable<S> {
429 type Item = Sample;
430
431 fn next(&mut self) -> Option<Self::Item> {
432 if let Some(sample) = self.inner.next() {
433 self.buffer.push(sample);
434 // If the buffer is full send it
435 if self.buffer.len() == self.chunk_size {
436 self.tx.push_normal(std::mem::take(&mut self.buffer));
437 }
438 Some(sample)
439 } else {
440 let last_chunk = std::mem::take(&mut self.buffer);
441 self.tx.push_last(last_chunk);
442 self.is_active.store(false, Ordering::Relaxed);
443 None
444 }
445 }
446
447 fn size_hint(&self) -> (usize, Option<usize>) {
448 self.inner.size_hint()
449 }
450}
451
452impl<S: Source> Source for Replayable<S> {
453 fn current_span_len(&self) -> Option<usize> {
454 self.inner.current_span_len()
455 }
456
457 fn channels(&self) -> ChannelCount {
458 self.inner.channels()
459 }
460
461 fn sample_rate(&self) -> SampleRate {
462 self.inner.sample_rate()
463 }
464
465 fn total_duration(&self) -> Option<Duration> {
466 self.inner.total_duration()
467 }
468}
469
/// Consumer half of [`RodioExt::replayable`]: yields the samples recorded by
/// the matching [`Replayable`], oldest chunk first, and then waits for newly
/// recorded samples while `source_is_active` reports true.
///
/// constant source, only works on a single span
#[derive(Debug)]
pub struct Replay {
    /// Consumer side of the chunk queue shared with the `Replayable`.
    rx: Arc<ReplayQueue>,
    /// Remaining samples of the chunk currently being yielded.
    buffer: std::vec::IntoIter<Sample>,
    /// How long to sleep before polling the queue again when it is empty.
    sleep_duration: Duration,
    /// Sample rate of the recorded source.
    sample_rate: SampleRate,
    /// Channel count of the recorded source.
    channel_count: ChannelCount,
    /// Set to `false` by the `Replayable` once its inner source ends.
    source_is_active: Arc<AtomicBool>,
}
480
481impl Replay {
482 pub fn source_is_active(&self) -> bool {
483 // - source could return None and not drop
484 // - source could be dropped before returning None
485 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
486 }
487
488 /// Duration of what is in the buffer and can be returned without blocking.
489 pub fn duration_ready(&self) -> Duration {
490 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
491
492 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
493 Duration::from_secs_f64(seconds_queued)
494 }
495
496 /// Number of samples in the buffer and can be returned without blocking.
497 pub fn samples_ready(&self) -> usize {
498 self.rx.len() + self.buffer.len()
499 }
500}
501
502impl Iterator for Replay {
503 type Item = Sample;
504
505 fn next(&mut self) -> Option<Self::Item> {
506 if let Some(sample) = self.buffer.next() {
507 return Some(sample);
508 }
509
510 loop {
511 if let Some(new_buffer) = self.rx.pop() {
512 self.buffer = new_buffer.into_iter();
513 return self.buffer.next();
514 }
515
516 if !self.source_is_active() {
517 return None;
518 }
519
520 // The queue does not support blocking on a next item. We want this queue as it
521 // is quite fast and provides a fixed size. We know how many samples are in a
522 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
523 std::thread::sleep(self.sleep_duration);
524 }
525 }
526
527 fn size_hint(&self) -> (usize, Option<usize>) {
528 ((self.rx.len() + self.buffer.len()), None)
529 }
530}
531
532impl Source for Replay {
533 fn current_span_len(&self) -> Option<usize> {
534 None // source is not compatible with spans
535 }
536
537 fn channels(&self) -> ChannelCount {
538 self.channel_count
539 }
540
541 fn sample_rate(&self) -> SampleRate {
542 self.sample_rate
543 }
544
545 fn total_duration(&self) -> Option<Duration> {
546 None
547 }
548}
549
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    // Five distinct samples so tests can tell exactly which ones were yielded.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    // One channel at a rate of 1, so a `Duration` in seconds equals a sample
    // count, and `replayable` uses chunks of a single sample.
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        // A buffer as long as the source sees every sample in one call.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Modifications made by the callback are what gets yielded.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        // 5 samples with a buffer of 3: the 2 trailing samples are dropped.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        // A buffer as long as the source sees every sample in one call.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        // Unlike `process_buffer`, samples past the last full buffer are still
        // yielded; only the callback misses them.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        // A 3s history holds 3 samples here: the replay yields those first,
        // then the remaining ones once the source has produced them.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        // A 2s history holds only the last 2 of the 5 recorded samples.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        // 40_000 samples at 16kHz mono is 2.5s; the 2s replay must hold at
        // least 2s minus one 100ms chunk.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        // Nothing is ready before the source runs; after half a second of
        // source samples at least that much, minus one chunk, is ready.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}