1use std::{num::NonZero, time::Duration};
2
3use denoise::{Denoiser, DenoiserError};
4use log::warn;
5use rodio::{ChannelCount, Sample, SampleRate, Source, conversions::ChannelCountConverter, nz};
6
7use crate::rodio_ext::resample::FixedResampler;
8pub use replayable::{Replay, ReplayDurationTooShort, Replayable};
9
10mod replayable;
11mod resample;
12
/// Maximum number of input channels that [`ToMono`] tracks and mixes; it also
/// sizes the per-channel level history kept by [`ToMono`].
const MAX_CHANNELS: usize = 8;
14
15// These all require constant sources (so the span is infinitely long)
16// this is not guaranteed by rodio however we know it to be true in all our
17// applications. Rodio desperately needs a constant source concept.
18pub trait RodioExt: Source + Sized {
19 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
20 where
21 F: FnMut(&mut [Sample; N]);
22 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
23 where
24 F: FnMut(&[Sample; N]);
25 fn replayable(
26 self,
27 duration: Duration,
28 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
29 fn take_samples(self, n: usize) -> TakeSamples<Self>;
30 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError>;
31 fn constant_params(
32 self,
33 channel_count: ChannelCount,
34 sample_rate: SampleRate,
35 ) -> ConstantChannelCount<FixedResampler<Self>>;
36 fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self>;
37 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self>;
38}
39
40impl<S: Source> RodioExt for S {
41 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
42 where
43 F: FnMut(&mut [Sample; N]),
44 {
45 ProcessBuffer {
46 inner: self,
47 callback,
48 buffer: [0.0; N],
49 next: N,
50 }
51 }
52 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
53 where
54 F: FnMut(&[Sample; N]),
55 {
56 InspectBuffer {
57 inner: self,
58 callback,
59 buffer: [0.0; N],
60 free: 0,
61 }
62 }
63 /// Maintains a live replay with a history of at least `duration` seconds.
64 ///
65 /// Note:
66 /// History can be 100ms longer if the source drops before or while the
67 /// replay is being read
68 ///
69 /// # Errors
70 /// If duration is smaller than 100ms
71 fn replayable(
72 self,
73 duration: Duration,
74 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
75 replayable::replayable(self, duration)
76 }
77 fn take_samples(self, n: usize) -> TakeSamples<S> {
78 TakeSamples {
79 inner: self,
80 left_to_take: n,
81 }
82 }
83 fn denoise(self) -> Result<Denoiser<Self>, DenoiserError> {
84 let res = Denoiser::try_new(self);
85 res
86 }
87 fn constant_params(
88 self,
89 channel_count: ChannelCount,
90 sample_rate: SampleRate,
91 ) -> ConstantChannelCount<FixedResampler<Self>> {
92 ConstantChannelCount::new(self.constant_samplerate(sample_rate), channel_count)
93 }
94 fn constant_samplerate(self, sample_rate: SampleRate) -> FixedResampler<Self> {
95 FixedResampler::new(self, sample_rate)
96 }
97 fn possibly_disconnected_channels_to_mono(self) -> ToMono<Self> {
98 ToMono::new(self)
99 }
100}
101
102pub struct ConstantChannelCount<S: Source> {
103 inner: ChannelCountConverter<S>,
104 channels: ChannelCount,
105 sample_rate: SampleRate,
106}
107
108impl<S: Source> ConstantChannelCount<S> {
109 fn new(source: S, target_channels: ChannelCount) -> Self {
110 let input_channels = source.channels();
111 let sample_rate = source.sample_rate();
112 let inner = ChannelCountConverter::new(source, input_channels, target_channels);
113 Self {
114 sample_rate,
115 inner,
116 channels: target_channels,
117 }
118 }
119}
120
121impl<S: Source> Iterator for ConstantChannelCount<S> {
122 type Item = rodio::Sample;
123
124 fn next(&mut self) -> Option<Self::Item> {
125 self.inner.next()
126 }
127
128 fn size_hint(&self) -> (usize, Option<usize>) {
129 self.inner.size_hint()
130 }
131}
132
133impl<S: Source> Source for ConstantChannelCount<S> {
134 fn current_span_len(&self) -> Option<usize> {
135 None
136 }
137
138 fn channels(&self) -> ChannelCount {
139 self.channels
140 }
141
142 fn sample_rate(&self) -> SampleRate {
143 self.sample_rate
144 }
145
146 fn total_duration(&self) -> Option<Duration> {
147 None // not supported (not used by us)
148 }
149}
150
151const TYPICAL_NOISE_FLOOR: Sample = 1e-3;
152
153/// constant source, only works on a single span
154pub struct ToMono<S> {
155 inner: S,
156 input_channel_count: ChannelCount,
157 connected_channels: ChannelCount,
158 /// running mean of second channel 'volume'
159 means: [f32; MAX_CHANNELS],
160}
161impl<S: Source> ToMono<S> {
162 fn new(input: S) -> Self {
163 let channels = input
164 .channels()
165 .min(const { NonZero::<u16>::new(MAX_CHANNELS as u16).unwrap() });
166 if channels < input.channels() {
167 warn!("Ignoring input channels {}..", channels.get());
168 }
169
170 Self {
171 connected_channels: channels,
172 input_channel_count: channels,
173 inner: input,
174 means: [TYPICAL_NOISE_FLOOR; MAX_CHANNELS],
175 }
176 }
177}
178
179impl<S: Source> Source for ToMono<S> {
180 fn current_span_len(&self) -> Option<usize> {
181 None
182 }
183
184 fn channels(&self) -> ChannelCount {
185 rodio::nz!(1)
186 }
187
188 fn sample_rate(&self) -> SampleRate {
189 self.inner.sample_rate()
190 }
191
192 fn total_duration(&self) -> Option<Duration> {
193 self.inner.total_duration()
194 }
195}
196
197fn update_mean(mean: &mut f32, sample: Sample) {
198 const HISTORY: f32 = 500.0;
199 *mean *= (HISTORY - 1.0) / HISTORY;
200 *mean += sample.abs() / HISTORY;
201}
202
203impl<S: Source> Iterator for ToMono<S> {
204 type Item = Sample;
205
206 fn next(&mut self) -> Option<Self::Item> {
207 let mut mono_sample = 0f32;
208 let mut active_channels = 0;
209 for channel in 0..self.input_channel_count.get() as usize {
210 let sample = self.inner.next()?;
211 mono_sample += sample;
212
213 update_mean(&mut self.means[channel], sample);
214 if self.means[channel] > TYPICAL_NOISE_FLOOR / 10.0 {
215 active_channels += 1;
216 }
217 }
218 mono_sample /= self.connected_channels.get() as f32;
219 self.connected_channels = NonZero::new(active_channels).unwrap_or(nz!(1));
220
221 Some(mono_sample)
222 }
223}
224
225/// constant source, only works on a single span
226pub struct TakeSamples<S> {
227 inner: S,
228 left_to_take: usize,
229}
230
231impl<S: Source> Iterator for TakeSamples<S> {
232 type Item = Sample;
233
234 fn next(&mut self) -> Option<Self::Item> {
235 if self.left_to_take == 0 {
236 None
237 } else {
238 self.left_to_take -= 1;
239 self.inner.next()
240 }
241 }
242
243 fn size_hint(&self) -> (usize, Option<usize>) {
244 (0, Some(self.left_to_take))
245 }
246}
247
248impl<S: Source> Source for TakeSamples<S> {
249 fn current_span_len(&self) -> Option<usize> {
250 None // does not support spans
251 }
252
253 fn channels(&self) -> ChannelCount {
254 self.inner.channels()
255 }
256
257 fn sample_rate(&self) -> SampleRate {
258 self.inner.sample_rate()
259 }
260
261 fn total_duration(&self) -> Option<Duration> {
262 Some(Duration::from_secs_f64(
263 self.left_to_take as f64
264 / self.sample_rate().get() as f64
265 / self.channels().get() as f64,
266 ))
267 }
268}
269
270/// constant source, only works on a single span
271pub struct ProcessBuffer<const N: usize, S, F>
272where
273 S: Source + Sized,
274 F: FnMut(&mut [Sample; N]),
275{
276 inner: S,
277 callback: F,
278 /// Buffer used for both input and output.
279 buffer: [Sample; N],
280 /// Next already processed sample is at this index
281 /// in buffer.
282 ///
283 /// If this is equal to the length of the buffer we have no more samples and
284 /// we must get new ones and process them
285 next: usize,
286}
287
288impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
289where
290 S: Source + Sized,
291 F: FnMut(&mut [Sample; N]),
292{
293 type Item = Sample;
294
295 fn next(&mut self) -> Option<Self::Item> {
296 self.next += 1;
297 if self.next < self.buffer.len() {
298 let sample = self.buffer[self.next];
299 return Some(sample);
300 }
301
302 for sample in &mut self.buffer {
303 *sample = self.inner.next()?
304 }
305 (self.callback)(&mut self.buffer);
306
307 self.next = 0;
308 Some(self.buffer[0])
309 }
310
311 fn size_hint(&self) -> (usize, Option<usize>) {
312 self.inner.size_hint()
313 }
314}
315
316impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
317where
318 S: Source + Sized,
319 F: FnMut(&mut [Sample; N]),
320{
321 fn current_span_len(&self) -> Option<usize> {
322 None
323 }
324
325 fn channels(&self) -> rodio::ChannelCount {
326 self.inner.channels()
327 }
328
329 fn sample_rate(&self) -> rodio::SampleRate {
330 self.inner.sample_rate()
331 }
332
333 fn total_duration(&self) -> Option<std::time::Duration> {
334 self.inner.total_duration()
335 }
336}
337
338/// constant source, only works on a single span
339pub struct InspectBuffer<const N: usize, S, F>
340where
341 S: Source + Sized,
342 F: FnMut(&[Sample; N]),
343{
344 inner: S,
345 callback: F,
346 /// Stores already emitted samples, once its full we call the callback.
347 buffer: [Sample; N],
348 /// Next free element in buffer. If this is equal to the buffer length
349 /// we have no more free lements.
350 free: usize,
351}
352
353impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
354where
355 S: Source + Sized,
356 F: FnMut(&[Sample; N]),
357{
358 type Item = Sample;
359
360 fn next(&mut self) -> Option<Self::Item> {
361 let Some(sample) = self.inner.next() else {
362 return None;
363 };
364
365 self.buffer[self.free] = sample;
366 self.free += 1;
367
368 if self.free == self.buffer.len() {
369 (self.callback)(&self.buffer);
370 self.free = 0
371 }
372
373 Some(sample)
374 }
375
376 fn size_hint(&self) -> (usize, Option<usize>) {
377 self.inner.size_hint()
378 }
379}
380
381impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
382where
383 S: Source + Sized,
384 F: FnMut(&[Sample; N]),
385{
386 fn current_span_len(&self) -> Option<usize> {
387 None
388 }
389
390 fn channels(&self) -> rodio::ChannelCount {
391 self.inner.channels()
392 }
393
394 fn sample_rate(&self) -> rodio::SampleRate {
395 self.inner.sample_rate()
396 }
397
398 fn total_duration(&self) -> Option<std::time::Duration> {
399 self.inner.total_duration()
400 }
401}
402
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    pub const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Source yielding [`SAMPLES`] (one channel, sample rate 1).
    pub fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let mut calls = 0;
            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    assert_eq!(*buffer, SAMPLES);
                    calls += 1;
                })
                .count();
            // Guards against the assertion above passing vacuously.
            assert_eq!(calls, 1);
        }
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let mut calls = 0;
            let yielded = input
                .process_buffer::<3, _>(|buffer| {
                    assert_eq!(buffer, &SAMPLES[..3]);
                    calls += 1;
                })
                .count();
            assert_eq!(yielded, 3);
            // The trailing two samples never form a whole buffer.
            assert_eq!(calls, 1);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let mut calls = 0;
            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    assert_eq!(*buffer, SAMPLES);
                    calls += 1;
                })
                .count();
            // Guards against the assertion above passing vacuously.
            assert_eq!(calls, 1);
        }
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let mut calls = 0;
            let yielded = input
                .inspect_buffer::<3, _>(|buffer| {
                    assert_eq!(buffer, &SAMPLES[..3]);
                    calls += 1;
                })
                .count();
            assert_eq!(yielded, SAMPLES.len());
            assert_eq!(calls, 1);
        }
    }

    mod take_samples {
        use super::*;

        #[test]
        fn yields_at_most_n_samples() {
            let taken: Vec<_> = test_source().take_samples(3).collect();
            assert_eq!(taken, &SAMPLES[..3]);
        }
        #[test]
        fn stops_when_inner_ends() {
            assert_eq!(test_source().take_samples(10).count(), SAMPLES.len());
        }
    }
}