1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicBool, Ordering},
5 },
6 time::Duration,
7};
8
9use crossbeam::queue::ArrayQueue;
10use rodio::{ChannelCount, Sample, SampleRate, Source};
11
/// Error returned by [`RodioExt::replayable`] when the requested history is
/// shorter than the supported minimum of 100ms.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplayDurationTooShort;

impl std::fmt::Display for ReplayDurationTooShort {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("replay duration must be at least 100ms")
    }
}

// Implementing `Error` lets callers box this value or propagate it with `?`
// into any error type that wraps `std::error::Error`.
impl std::error::Error for ReplayDurationTooShort {}
14
/// Extension methods for any [`Source`].
///
/// A blanket implementation in this file provides them for every type that
/// implements [`Source`].
pub trait RodioExt: Source + Sized {
    /// Wraps the source in a [`ProcessBuffer`] whose `callback` receives
    /// mutable access to buffers of `N` samples.
    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
    where
        F: FnMut(&mut [Sample; N]);
    /// Wraps the source in an [`InspectBuffer`] whose `callback` receives
    /// read-only access to buffers of `N` samples.
    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
    where
        F: FnMut(&[Sample; N]);
    /// Splits the source into a [`Replay`] handle and a [`Replayable`]
    /// wrapper around the source, configured with `duration`.
    ///
    /// # Errors
    /// Returns [`ReplayDurationTooShort`] when `duration` is rejected; the
    /// blanket implementation rejects anything below 100ms.
    fn replayable(
        self,
        duration: Duration,
    ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort>;
    /// Wraps the source in a [`TakeSamples`] limited to `n` samples.
    fn take_samples(self, n: usize) -> TakeSamples<Self>;
}
28
29impl<S: Source> RodioExt for S {
30 fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
31 where
32 F: FnMut(&mut [Sample; N]),
33 {
34 ProcessBuffer {
35 inner: self,
36 callback,
37 buffer: [0.0; N],
38 next: N,
39 }
40 }
41 fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
42 where
43 F: FnMut(&[Sample; N]),
44 {
45 InspectBuffer {
46 inner: self,
47 callback,
48 buffer: [0.0; N],
49 free: 0,
50 }
51 }
52 /// Maintains a live replay with a history of at least `duration` seconds.
53 ///
54 /// Note:
55 /// History can be 100ms longer if the source drops before or while the
56 /// replay is being read
57 ///
58 /// # Errors
59 /// If duration is smaller then 100ms
60 fn replayable(
61 self,
62 duration: Duration,
63 ) -> Result<(Replay, Replayable<Self>), ReplayDurationTooShort> {
64 if duration < Duration::from_millis(100) {
65 return Err(ReplayDurationTooShort);
66 }
67
68 let samples_per_second = self.sample_rate().get() as usize * self.channels().get() as usize;
69 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
70 let samples_to_queue =
71 (samples_to_queue as usize).next_multiple_of(self.channels().get().into());
72
73 let chunk_size =
74 (samples_per_second.div_ceil(10)).next_multiple_of(self.channels().get() as usize);
75 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
76
77 let is_active = Arc::new(AtomicBool::new(true));
78 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
79 Ok((
80 Replay {
81 rx: Arc::clone(&queue),
82 buffer: Vec::new().into_iter(),
83 sleep_duration: duration / 2,
84 sample_rate: self.sample_rate(),
85 channel_count: self.channels(),
86 source_is_active: is_active.clone(),
87 },
88 Replayable {
89 tx: queue,
90 inner: self,
91 buffer: Vec::with_capacity(chunk_size),
92 chunk_size,
93 is_active,
94 },
95 ))
96 }
97 fn take_samples(self, n: usize) -> TakeSamples<S> {
98 TakeSamples {
99 inner: self,
100 left_to_take: n,
101 }
102 }
103}
104
/// Source adapter created by [`RodioExt::take_samples`] that stops after a
/// fixed budget of samples.
pub struct TakeSamples<S> {
    /// The wrapped source that samples are pulled from.
    inner: S,
    /// How many more times `next` may still pull from `inner`.
    left_to_take: usize,
}
109
110impl<S: Source> Iterator for TakeSamples<S> {
111 type Item = Sample;
112
113 fn next(&mut self) -> Option<Self::Item> {
114 if self.left_to_take == 0 {
115 None
116 } else {
117 self.left_to_take -= 1;
118 self.inner.next()
119 }
120 }
121
122 fn size_hint(&self) -> (usize, Option<usize>) {
123 (0, Some(self.left_to_take))
124 }
125}
126
127impl<S: Source> Source for TakeSamples<S> {
128 fn current_span_len(&self) -> Option<usize> {
129 None // does not support spans
130 }
131
132 fn channels(&self) -> ChannelCount {
133 self.inner.channels()
134 }
135
136 fn sample_rate(&self) -> SampleRate {
137 self.inner.sample_rate()
138 }
139
140 fn total_duration(&self) -> Option<Duration> {
141 Some(Duration::from_secs_f64(
142 self.left_to_take as f64
143 / self.sample_rate().get() as f64
144 / self.channels().get() as f64,
145 ))
146 }
147}
148
149#[derive(Debug)]
150struct ReplayQueue {
151 inner: ArrayQueue<Vec<Sample>>,
152 normal_chunk_len: usize,
153 /// The last chunk in the queue may be smaller then
154 /// the normal chunk size. This is always equal to the
155 /// size of the last element in the queue.
156 /// (so normally chunk_size)
157 last_chunk: Mutex<Vec<Sample>>,
158}
159
160impl ReplayQueue {
161 fn new(queue_len: usize, chunk_size: usize) -> Self {
162 Self {
163 inner: ArrayQueue::new(queue_len),
164 normal_chunk_len: chunk_size,
165 last_chunk: Mutex::new(Vec::new()),
166 }
167 }
168 /// Returns the length in samples
169 fn len(&self) -> usize {
170 self.inner.len().saturating_sub(1) * self.normal_chunk_len
171 + self
172 .last_chunk
173 .lock()
174 .expect("Self::push_last can not poison this lock")
175 .len()
176 }
177
178 fn pop(&self) -> Option<Vec<Sample>> {
179 self.inner.pop() // removes element that was inserted first
180 }
181
182 fn push_last(&self, mut samples: Vec<Sample>) {
183 let mut last_chunk = self
184 .last_chunk
185 .lock()
186 .expect("Self::len can not poison this lock");
187 std::mem::swap(&mut *last_chunk, &mut samples);
188 }
189
190 fn push_normal(&self, samples: Vec<Sample>) {
191 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
192 }
193}
194
/// Source adapter created by [`RodioExt::process_buffer`]. It reads `N`
/// samples at a time and lets a callback modify them before they are yielded.
pub struct ProcessBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&mut [Sample; N]),
{
    /// The wrapped source that samples are read from.
    inner: S,
    /// Called with each freshly filled buffer before any of its samples is
    /// yielded.
    callback: F,
    /// Buffer used for both input and output.
    buffer: [Sample; N],
    /// Index in `buffer` of the sample that was yielded most recently; the
    /// iterator advances it before reading the next sample.
    ///
    /// Once advancing moves it to or past the length of the buffer we have no
    /// more processed samples and we must get new ones and process them.
    next: usize,
}
211
212impl<const N: usize, S, F> Iterator for ProcessBuffer<N, S, F>
213where
214 S: Source + Sized,
215 F: FnMut(&mut [Sample; N]),
216{
217 type Item = Sample;
218
219 fn next(&mut self) -> Option<Self::Item> {
220 self.next += 1;
221 if self.next < self.buffer.len() {
222 let sample = self.buffer[self.next];
223 return Some(sample);
224 }
225
226 for sample in &mut self.buffer {
227 *sample = self.inner.next()?
228 }
229 (self.callback)(&mut self.buffer);
230
231 self.next = 0;
232 Some(self.buffer[0])
233 }
234
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 self.inner.size_hint()
237 }
238}
239
240impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
241where
242 S: Source + Sized,
243 F: FnMut(&mut [Sample; N]),
244{
245 fn current_span_len(&self) -> Option<usize> {
246 None
247 }
248
249 fn channels(&self) -> rodio::ChannelCount {
250 self.inner.channels()
251 }
252
253 fn sample_rate(&self) -> rodio::SampleRate {
254 self.inner.sample_rate()
255 }
256
257 fn total_duration(&self) -> Option<std::time::Duration> {
258 self.inner.total_duration()
259 }
260}
261
/// Source adapter created by [`RodioExt::inspect_buffer`]. It yields every
/// sample unchanged and shows them to a callback in buffers of `N` samples.
pub struct InspectBuffer<const N: usize, S, F>
where
    S: Source + Sized,
    F: FnMut(&[Sample; N]),
{
    /// The wrapped source that samples are read from.
    inner: S,
    /// Called with the buffer each time it has been filled completely.
    callback: F,
    /// Stores already emitted samples; once it is full we call the callback.
    buffer: [Sample; N],
    /// Next free element in buffer. It is reset to zero right after the
    /// callback ran, so the buffer is then filled again from the start.
    free: usize,
}
275
276impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
277where
278 S: Source + Sized,
279 F: FnMut(&[Sample; N]),
280{
281 type Item = Sample;
282
283 fn next(&mut self) -> Option<Self::Item> {
284 let Some(sample) = self.inner.next() else {
285 return None;
286 };
287
288 self.buffer[self.free] = sample;
289 self.free += 1;
290
291 if self.free == self.buffer.len() {
292 (self.callback)(&self.buffer);
293 self.free = 0
294 }
295
296 Some(sample)
297 }
298
299 fn size_hint(&self) -> (usize, Option<usize>) {
300 self.inner.size_hint()
301 }
302}
303
304impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
305where
306 S: Source + Sized,
307 F: FnMut(&[Sample; N]),
308{
309 fn current_span_len(&self) -> Option<usize> {
310 None
311 }
312
313 fn channels(&self) -> rodio::ChannelCount {
314 self.inner.channels()
315 }
316
317 fn sample_rate(&self) -> rodio::SampleRate {
318 self.inner.sample_rate()
319 }
320
321 fn total_duration(&self) -> Option<std::time::Duration> {
322 self.inner.total_duration()
323 }
324}
325
/// Source wrapper created by [`RodioExt::replayable`]. It yields the samples
/// of the wrapped source and records them in a queue shared with the paired
/// [`Replay`].
#[derive(Debug)]
pub struct Replayable<S: Source> {
    /// The wrapped source that samples are read from.
    inner: S,
    /// Samples collected for the chunk that is currently being assembled.
    buffer: Vec<Sample>,
    /// Number of samples at which `buffer` is handed to the queue.
    chunk_size: usize,
    /// Producer side of the queue shared with the paired [`Replay`].
    tx: Arc<ReplayQueue>,
    /// Shared flag, cleared once the wrapped source has returned `None`.
    is_active: Arc<AtomicBool>,
}
334
335impl<S: Source> Iterator for Replayable<S> {
336 type Item = Sample;
337
338 fn next(&mut self) -> Option<Self::Item> {
339 if let Some(sample) = self.inner.next() {
340 self.buffer.push(sample);
341 if self.buffer.len() == self.chunk_size {
342 self.tx.push_normal(std::mem::take(&mut self.buffer));
343 }
344 Some(sample)
345 } else {
346 let last_chunk = std::mem::take(&mut self.buffer);
347 self.tx.push_last(last_chunk);
348 self.is_active.store(false, Ordering::Relaxed);
349 None
350 }
351 }
352
353 fn size_hint(&self) -> (usize, Option<usize>) {
354 self.inner.size_hint()
355 }
356}
357
358impl<S: Source> Source for Replayable<S> {
359 fn current_span_len(&self) -> Option<usize> {
360 self.inner.current_span_len()
361 }
362
363 fn channels(&self) -> ChannelCount {
364 self.inner.channels()
365 }
366
367 fn sample_rate(&self) -> SampleRate {
368 self.inner.sample_rate()
369 }
370
371 fn total_duration(&self) -> Option<Duration> {
372 self.inner.total_duration()
373 }
374}
375
/// Source created by [`RodioExt::replayable`] that yields the samples
/// recorded by the paired [`Replayable`].
#[derive(Debug)]
pub struct Replay {
    /// Consumer side of the queue shared with the paired [`Replayable`].
    rx: Arc<ReplayQueue>,
    /// Samples of the chunk that was popped from the queue most recently and
    /// is currently being yielded.
    buffer: std::vec::IntoIter<Sample>,
    /// How long `next` sleeps before it looks at the queue again when the
    /// queue is empty and the source is still considered active.
    sleep_duration: Duration,
    /// Sample rate copied from the source when the replay was created.
    sample_rate: SampleRate,
    /// Channel count copied from the source when the replay was created.
    channel_count: ChannelCount,
    /// Shared flag, cleared by the paired [`Replayable`] once its wrapped
    /// source has returned `None`.
    source_is_active: Arc<AtomicBool>,
}
385
386impl Replay {
387 pub fn source_is_active(&self) -> bool {
388 // - source could return None and not drop
389 // - source could be dropped before returning None
390 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
391 }
392
393 /// Duration of what is in the buffer and can be returned without blocking.
394 pub fn duration_ready(&self) -> Duration {
395 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
396
397 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
398 Duration::from_secs_f64(seconds_queued)
399 }
400
401 /// Number of samples in the buffer and can be returned without blocking.
402 pub fn samples_ready(&self) -> usize {
403 self.rx.len() + self.buffer.len()
404 }
405}
406
407impl Iterator for Replay {
408 type Item = Sample;
409
410 fn next(&mut self) -> Option<Self::Item> {
411 if let Some(sample) = self.buffer.next() {
412 return Some(sample);
413 }
414
415 loop {
416 if let Some(new_buffer) = self.rx.pop() {
417 self.buffer = new_buffer.into_iter();
418 return self.buffer.next();
419 }
420
421 if !self.source_is_active() {
422 return None;
423 }
424
425 std::thread::sleep(self.sleep_duration);
426 }
427 }
428
429 fn size_hint(&self) -> (usize, Option<usize>) {
430 ((self.rx.len() + self.buffer.len()), None)
431 }
432}
433
434impl Source for Replay {
435 fn current_span_len(&self) -> Option<usize> {
436 None // source is not compatible with spans
437 }
438
439 fn channels(&self) -> ChannelCount {
440 self.channel_count
441 }
442
443 fn sample_rate(&self) -> SampleRate {
444 self.sample_rate
445 }
446
447 fn total_duration(&self) -> Option<Duration> {
448 None
449 }
450}
451
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;

    /// Distinct sample values so that tests can tell which samples they got.
    const SAMPLES: [Sample; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

    /// Source yielding [`SAMPLES`].
    // NOTE(review): presumably one channel at a sample rate of 1 - confirm
    // the argument order of `StaticSamplesBuffer::new`.
    fn test_source() -> StaticSamplesBuffer {
        StaticSamplesBuffer::new(nz!(1), nz!(1), &SAMPLES)
    }

    mod process_buffer {
        use super::*;

        /// With a buffer as long as the source the callback sees every sample.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        /// Changes made by the callback show up in the yielded samples.
        #[test]
        fn callback_modifies_yielded() {
            let input = test_source();

            let yielded: Vec<_> = input
                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
                    for sample in buffer {
                        *sample += 1.0;
                    }
                })
                .collect();
            assert_eq!(
                yielded,
                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
            )
        }
        /// Five samples with a buffer of three: only the first whole buffer is
        /// processed and yielded.
        #[test]
        fn source_truncates_to_whole_buffers() {
            let input = test_source();

            let yielded = input
                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, 3)
        }
    }

    mod inspect_buffer {
        use super::*;

        /// With a buffer as long as the source the callback sees every sample.
        #[test]
        fn callback_gets_all_samples() {
            let input = test_source();

            let _ = input
                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
                .count();
        }
        /// Five samples with a buffer of three: the callback only sees the
        /// first whole buffer, yet all five samples are yielded.
        #[test]
        fn source_does_not_truncate() {
            let input = test_source();

            let yielded = input
                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
                .count();
            assert_eq!(yielded, SAMPLES.len())
        }
    }

    mod instant_replay {
        use super::*;

        /// The replay first yields the recorded history and then whatever the
        /// source produces afterwards.
        #[test]
        fn continues_after_history() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(3))
                .expect("longer then 100ms");

            // Record three samples and read them back.
            source.by_ref().take(3).count();
            let yielded: Vec<Sample> = replay.by_ref().take(3).collect();
            assert_eq!(&yielded, &SAMPLES[0..3],);

            // Finish the source; the replay yields the rest and then ends.
            source.count();
            let yielded: Vec<Sample> = replay.collect();
            assert_eq!(&yielded, &SAMPLES[3..5],);
        }

        /// With a history of two seconds only the two newest samples of this
        /// source are kept.
        #[test]
        fn keeps_only_latest() {
            let input = test_source();

            let (mut replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            source.by_ref().take(5).count(); // get all items but do not end the source
            let yielded: Vec<Sample> = replay.by_ref().take(2).collect();
            assert_eq!(&yielded, &SAMPLES[3..5]);
            source.count(); // exhaust source
            assert_eq!(replay.next(), None);
        }

        /// After recording more than the history length, close to two seconds
        /// of samples must be ready; a shortfall of up to 100ms is accepted.
        #[test]
        fn keeps_correct_amount_of_seconds() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);

            let (replay, mut source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");

            // exhaust but do not yet end source
            source.by_ref().take(40_000).count();

            // take all samples we can without blocking
            let ready = replay.samples_ready();
            let n_yielded = replay.take_samples(ready).count();

            // Samples in two seconds of audio.
            let max = source.sample_rate().get() * source.channels().get() as u32 * 2;
            let margin = 16_000 / 10; // 100ms
            assert!(n_yielded as u32 >= max - margin);
        }

        /// Nothing is ready before the source is read; afterwards what was
        /// read is reported as ready, give or take 100ms.
        #[test]
        fn samples_ready() {
            let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
            let (mut replay, source) = input
                .replayable(Duration::from_secs(2))
                .expect("longer then 100ms");
            assert_eq!(replay.by_ref().samples_ready(), 0);

            source.take(8000).count(); // half a second
            let margin = 16_000 / 10; // 100ms
            let ready = replay.samples_ready();
            assert!(ready >= 8000 - margin);
        }
    }
}