1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicBool, Ordering},
5 },
6 time::Duration,
7};
8
9use crossbeam::queue::ArrayQueue;
10use rodio::{ChannelCount, Sample, SampleRate, Source};
11
12#[derive(Debug, thiserror::Error)]
13#[error("Replay duration is too short must be >= 100ms")]
14pub struct ReplayDurationTooShort;
15
16pub trait RodioExt: Source + Sized {
17 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
18 where
19 F: FnMut(&mut [Sample; N]);
20 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
21 where
22 F: FnMut(&[Sample; N]);
23 fn replayable(
24 self,
25 duration: Duration,
26 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
27 fn take_samples(self, n: usize) -> TakeSamples<Self>;
28}
29
30impl<S: Source> RodioExt for S {
31 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
32 where
33 F: FnMut(&mut [Sample; N]),
34 {
35 ProcessBuffer {
36 inner: self,
37 callback,
38 buffer: [0.0; N],
39 next: N,
40 }
41 }
42 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
43 where
44 F: FnMut(&[Sample; N]),
45 {
46 InspectBuffer {
47 inner: self,
48 callback,
49 buffer: [0.0; N],
50 free: 0,
51 }
52 }
53 /// Maintains a live replay with a history of at least `duration` seconds.
54 ///
55 /// Note:
56 /// History can be 100ms longer if the source drops before or while the
57 /// replay is being read
58 ///
59 /// # Errors
60 /// If duration is smaller than 100ms
61 fn replayable(
62 self,
63 duration: Duration,
64 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
65 if duration < Duration::from_millis(100) {
66 return Err(ReplayDurationTooShort);
67 }
68
69 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
70 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
71 let samples_to_queue =
72 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
73
74 let chunk_size =
75 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
76 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
77
78 let is_active = Arc::new(AtomicBool::new(true));
79 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
80 Ok((
81 Replay {
82 rx: Arc::clone(&queue),
83 buffer: Vec::new().into_iter(),
84 sleep_duration: duration / 2,
85 sample_rate: self.sample_rate(),
86 channel_count: self.channels(),
87 source_is_active: is_active.clone(),
88 },
89 Replayable {
90 tx: queue,
91 inner: self,
92 buffer: Vec::with_capacity(chunk_size),
93 chunk_size,
94 is_active,
95 },
96 ))
97 }
98 fn take_samples(self, n: usize) -> TakeSamples<S> {
99 TakeSamples {
100 inner: self,
101 left_to_take: n,
102 }
103 }
104}
105
/// Source returned by [`RodioExt::take_samples`]: yields at most a fixed
/// number of samples from the wrapped source.
pub struct TakeSamples<S> {
    /// How many more samples may still be yielded.
    left_to_take: usize,
    /// The wrapped source.
    inner: S,
}
110
111impl<S: Source> Iterator for TakeSamples<S> {
112 type Item = Sample;
113
114 fn next(&mut self) -> Option<Self::Item> {
115 if self.left_to_take == 0 {
116 None
117 } else {
118 self.left_to_take -= 1;
119 self.inner.next()
120 }
121 }
122
123 fn size_hint(&self) -> (usize, Option<usize>) {
124 (0, Some(self.left_to_take))
125 }
126}
127
128impl<S: Source> Source for TakeSamples<S> {
129 fn current_span_len(&self) -> Option<usize> {
130 None // does not support spans
131 }
132
133 fn channels(&self) -> ChannelCount {
134 self.inner.channels()
135 }
136
137 fn sample_rate(&self) -> SampleRate {
138 self.inner.sample_rate()
139 }
140
141 fn total_duration(&self) -> Option<Duration> {
142 Some(Duration::from_secs_f64(
143 self.left_to_take as f64
144 / self.sample_rate().get() as f64
145 / self.channels().get() as f64,
146 ))
147 }
148}
149
150#[derive(Debug)]
151struct ReplayQueue {
152 inner: ArrayQueue<Vec<Sample>>,
153 normal_chunk_len: usize,
154 /// The last chunk in the queue may be smaller than
155 /// the normal chunk size. This is always equal to the
156 /// size of the last element in the queue.
157 /// (so normally chunk_size)
158 last_chunk: Mutex<Vec<Sample>>,
159}
160
161impl ReplayQueue {
162 fn new(queue_len: usize, chunk_size: usize) -> Self {
163 Self {
164 inner: ArrayQueue::new(queue_len),
165 normal_chunk_len: chunk_size,
166 last_chunk: Mutex::new(Vec::new()),
167 }
168 }
169 /// Returns the length in samples
170 fn len(&self) -> usize {
171 self.inner.len().saturating_sub(1) * self.normal_chunk_len
172 + self
173 .last_chunk
174 .lock()
175 .expect("Self::push_last can not poison this lock")
176 .len()
177 }
178
179 fn pop(&self) -> Option<Vec<Sample>> {
180 self.inner.pop() // removes element that was inserted first
181 }
182
183 fn push_last(&self, mut samples: Vec<Sample>) {
184 let mut last_chunk = self
185 .last_chunk
186 .lock()
187 .expect("Self::len can not poison this lock");
188 std::mem::swap(&mut *last_chunk, &mut samples);
189 }
190
191 fn push_normal(&self, samples: Vec<Sample>) {
192 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
193 }
194}
195
196pub struct ProcessBuffer<const N: usize, S, F>
197where
198 S: Source + Sized,
199 F: FnMut(&mut [Sample; N]),
200{
201 inner: S,
202 callback: F,
203 /// Buffer used for both input and output.
204 buffer: [Sample; N],
205 /// Next already processed sample is at this index
206 /// in buffer.
207 ///
208 /// If this is equal to the length of the buffer we have no more samples and
209 /// we must get new ones and process them
210 next: usize,
211}
212
213impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
214where
215 S: Source + Sized,
216 F: FnMut(&mut [Sample; N]),
217{
218 type Item = Sample;
219
220 fn next(&mut self) -> Option<Self::Item> {
221 self.next += 1;
222 if self.next < self.buffer.len() {
223 let sample = self.buffer[self.next];
224 return Some(sample);
225 }
226
227 for sample in &mut self.buffer {
228 *sample = self.inner.next()?
229 }
230 (self.callback)(&mut self.buffer);
231
232 self.next = 0;
233 Some(self.buffer[0])
234 }
235
236 fn size_hint(&self) -> (usize, Option<usize>) {
237 self.inner.size_hint()
238 }
239}
240
241impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
242where
243 S: Source + Sized,
244 F: FnMut(&mut [Sample; N]),
245{
246 fn current_span_len(&self) -> Option<usize> {
247 None
248 }
249
250 fn channels(&self) -> rodio::ChannelCount {
251 self.inner.channels()
252 }
253
254 fn sample_rate(&self) -> rodio::SampleRate {
255 self.inner.sample_rate()
256 }
257
258 fn total_duration(&self) -> Option<std::time::Duration> {
259 self.inner.total_duration()
260 }
261}
262
263pub struct InspectBuffer<const N: usize, S, F>
264where
265 S: Source + Sized,
266 F: FnMut(&[Sample; N]),
267{
268 inner: S,
269 callback: F,
270 /// Stores already emitted samples, once its full we call the callback.
271 buffer: [Sample; N],
272 /// Next free element in buffer. If this is equal to the buffer length
273 /// we have no more free lements.
274 free: usize,
275}
276
277impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
278where
279 S: Source + Sized,
280 F: FnMut(&[Sample; N]),
281{
282 type Item = Sample;
283
284 fn next(&mut self) -> Option<Self::Item> {
285 let Some(sample) = self.inner.next() else {
286 return None;
287 };
288
289 self.buffer[self.free] = sample;
290 self.free += 1;
291
292 if self.free == self.buffer.len() {
293 (self.callback)(&self.buffer);
294 self.free = 0
295 }
296
297 Some(sample)
298 }
299
300 fn size_hint(&self) -> (usize, Option<usize>) {
301 self.inner.size_hint()
302 }
303}
304
305impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
306where
307 S: Source + Sized,
308 F: FnMut(&[Sample; N]),
309{
310 fn current_span_len(&self) -> Option<usize> {
311 None
312 }
313
314 fn channels(&self) -> rodio::ChannelCount {
315 self.inner.channels()
316 }
317
318 fn sample_rate(&self) -> rodio::SampleRate {
319 self.inner.sample_rate()
320 }
321
322 fn total_duration(&self) -> Option<std::time::Duration> {
323 self.inner.total_duration()
324 }
325}
326
327#[derive(Debug)]
328pub struct Replayable<S: Source> {
329 inner: S,
330 buffer: Vec<Sample>,
331 chunk_size: usize,
332 tx: Arc<ReplayQueue>,
333 is_active: Arc<AtomicBool>,
334}
335
336impl<S: Source> Iterator for Replayable<S> {
337 type Item = Sample;
338
339 fn next(&mut self) -> Option<Self::Item> {
340 if let Some(sample) = self.inner.next() {
341 self.buffer.push(sample);
342 // If the buffer is full send it
343 if self.buffer.len() == self.chunk_size {
344 self.tx.push_normal(std::mem::take(&mut self.buffer));
345 }
346 Some(sample)
347 } else {
348 let last_chunk = std::mem::take(&mut self.buffer);
349 self.tx.push_last(last_chunk);
350 self.is_active.store(false, Ordering::Relaxed);
351 None
352 }
353 }
354
355 fn size_hint(&self) -> (usize, Option<usize>) {
356 self.inner.size_hint()
357 }
358}
359
360impl<S: Source> Source for Replayable<S> {
361 fn current_span_len(&self) -> Option<usize> {
362 self.inner.current_span_len()
363 }
364
365 fn channels(&self) -> ChannelCount {
366 self.inner.channels()
367 }
368
369 fn sample_rate(&self) -> SampleRate {
370 self.inner.sample_rate()
371 }
372
373 fn total_duration(&self) -> Option<Duration> {
374 self.inner.total_duration()
375 }
376}
377
378#[derive(Debug)]
379pub struct Replay {
380 rx: Arc<ReplayQueue>,
381 buffer: std::vec::IntoIter<Sample>,
382 sleep_duration: Duration,
383 sample_rate: SampleRate,
384 channel_count: ChannelCount,
385 source_is_active: Arc<AtomicBool>,
386}
387
388impl Replay {
389 pub fn source_is_active(&self) -> bool {
390 // - source could return None and not drop
391 // - source could be dropped before returning None
392 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
393 }
394
395 /// Duration of what is in the buffer and can be returned without blocking.
396 pub fn duration_ready(&self) -> Duration {
397 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
398
399 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
400 Duration::from_secs_f64(seconds_queued)
401 }
402
403 /// Number of samples in the buffer and can be returned without blocking.
404 pub fn samples_ready(&self) -> usize {
405 self.rx.len() + self.buffer.len()
406 }
407}
408
409impl Iterator for Replay {
410 type Item = Sample;
411
412 fn next(&mut self) -> Option<Self::Item> {
413 if let Some(sample) = self.buffer.next() {
414 return Some(sample);
415 }
416
417 loop {
418 if let Some(new_buffer) = self.rx.pop() {
419 self.buffer = new_buffer.into_iter();
420 return self.buffer.next();
421 }
422
423 if !self.source_is_active() {
424 return None;
425 }
426
427 // The queue does not support blocking on a next item. We want this queue as it
428 // is quite fast and provides a fixed size. We know how many samples are in a
429 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
430 std::thread::sleep(self.sleep_duration);
431 }
432 }
433
434 fn size_hint(&self) -> (usize, Option<usize>) {
435 ((self.rx.len() + self.buffer.len()), None)
436 }
437}
438
439impl Source for Replay {
440 fn current_span_len(&self) -> Option<usize> {
441 None // source is not compatible with spans
442 }
443
444 fn channels(&self) -> ChannelCount {
445 self.channel_count
446 }
447
448 fn sample_rate(&self) -> SampleRate {
449 self.sample_rate
450 }
451
452 fn total_duration(&self) -> Option<Duration> {
453 None
454 }
455}
456
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Mono source at 1 sample per second yielding [`SAMPLES`].
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let mut calls = 0;
            let _ = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    calls += 1;
                    assert_eq!(*buffer, SAMPLES);
                })
                .count();
            // Without this the assertion above would pass vacuously if the
            // callback never ran.
            assert_eq!(calls, 1);
        }

        #[test]
        fn callback_modifies_yielded() {
            let yielded: Vec<_> = test_source()
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            );
        }

        #[test]
        fn callback_runs_once_per_buffer() {
            let mut seen: Vec<[Sample; 2]> = Vec::new();
            let yielded: Vec<Sample> = test_source()
                .process_buffer::<2, _>(|buffer| seen.push(*buffer))
                .collect();
            assert_eq!(&yielded, &SAMPLES[..4]);
            assert_eq!(seen, [[SAMPLES[0], SAMPLES[1]], [SAMPLES[2], SAMPLES[3]]]);
        }

        #[test]
        fn source_truncates_to_whole_buffers() {
            let mut calls = 0;
            let yielded = test_source()
                .process_buffer::<3, _>(|buffer| {
                    calls += 1;
                    assert_eq!(buffer, &SAMPLES[..3]);
                })
                .count();
            assert_eq!(yielded, 3);
            assert_eq!(calls, 1);
        }
    }

    mod inspect_buffer {
        use super::*;

        #[test]
        fn callback_gets_all_samples() {
            let mut calls = 0;
            let _ = test_source()
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    calls += 1;
                    assert_eq!(*buffer, SAMPLES);
                })
                .count();
            assert_eq!(calls, 1);
        }

        #[test]
        fn source_does_not_truncate() {
            let mut calls = 0;
            let yielded = test_source()
                .inspect_buffer::<3, _>(|buffer| {
                    calls += 1;
                    assert_eq!(buffer, &SAMPLES[..3]);
                })
                .count();
            assert_eq!(yielded, SAMPLES.len());
            assert_eq!(calls, 1);
        }
    }

    mod take_samples {
        use super::*;

        #[test]
        fn yields_at_most_n() {
            let yielded: Vec<Sample> = test_source().take_samples(3).collect();
            assert_eq!(&yielded, &SAMPLES[..3]);
        }

        #[test]
        fn stops_when_source_ends() {
            assert_eq!(test_source().take_samples(10).count(), SAMPLES.len());
        }
    }

    mod instant_replay {
        use super::*;

        #[test]
        fn rejects_too_short_duration() {
            assert!(
                test_source()
                    .replayable(Duration::from_millis(99))
                    .is_err()
            );
        }

        #[test]
        fn continues_after_history() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(3))
                .expect("longer than 100ms");

            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3]);

            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
        }

        #[test]
        fn keeps_only_latest() {
            let (mut replay, mut source) = test_source()
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer than 100ms");
            assert_eq!(replay.samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}