1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicBool, Ordering},
5 },
6 time::Duration,
7};
8
9use crossbeam::queue::ArrayQueue;
10use rodio::{ChannelCount, Sample, SampleRate, Source};
11
12#[derive(Debug, thiserror::Error)]
13#[error("Replay duration is too short must be >= 100ms")]
14pub struct ReplayDurationTooShort;
15
16pub fn replayable<S: Source>(
17 source: S,
18 duration: Duration,
19) -> Result<(Replay, Replayable<S>), ReplayDurationTooShort> {
20 if duration < Duration::from_millis(100) {
21 return Err(ReplayDurationTooShort);
22 }
23
24 let samples_per_second = source.sample_rate().get() as usize * source.channels().get() as usize;
25 let samples_to_queue = duration.as_secs_f64() * samples_per_second as f64;
26 let samples_to_queue =
27 (samples_to_queue as usize).next_multiple_of(source.channels().get().into());
28
29 let chunk_size =
30 (samples_per_second.div_ceil(10)).next_multiple_of(source.channels().get() as usize);
31 let chunks_to_queue = samples_to_queue.div_ceil(chunk_size);
32
33 let is_active = Arc::new(AtomicBool::new(true));
34 let queue = Arc::new(ReplayQueue::new(chunks_to_queue, chunk_size));
35 Ok((
36 Replay {
37 rx: Arc::clone(&queue),
38 buffer: Vec::new().into_iter(),
39 sleep_duration: duration / 2,
40 sample_rate: source.sample_rate(),
41 channel_count: source.channels(),
42 source_is_active: is_active.clone(),
43 },
44 Replayable {
45 tx: queue,
46 inner: source,
47 buffer: Vec::with_capacity(chunk_size),
48 chunk_size,
49 is_active,
50 },
51 ))
52}
53
54/// constant source, only works on a single span
55#[derive(Debug)]
56struct ReplayQueue {
57 inner: ArrayQueue<Vec<Sample>>,
58 normal_chunk_len: usize,
59 /// The last chunk in the queue may be smaller than
60 /// the normal chunk size. This is always equal to the
61 /// size of the last element in the queue.
62 /// (so normally chunk_size)
63 last_chunk: Mutex<Vec<Sample>>,
64}
65
66impl ReplayQueue {
67 fn new(queue_len: usize, chunk_size: usize) -> Self {
68 Self {
69 inner: ArrayQueue::new(queue_len),
70 normal_chunk_len: chunk_size,
71 last_chunk: Mutex::new(Vec::new()),
72 }
73 }
74 /// Returns the length in samples
75 fn len(&self) -> usize {
76 self.inner.len().saturating_sub(1) * self.normal_chunk_len
77 + self
78 .last_chunk
79 .lock()
80 .expect("Self::push_last can not poison this lock")
81 .len()
82 }
83
84 fn pop(&self) -> Option<Vec<Sample>> {
85 self.inner.pop() // removes element that was inserted first
86 }
87
88 fn push_last(&self, mut samples: Vec<Sample>) {
89 let mut last_chunk = self
90 .last_chunk
91 .lock()
92 .expect("Self::len can not poison this lock");
93 std::mem::swap(&mut *last_chunk, &mut samples);
94 }
95
96 fn push_normal(&self, samples: Vec<Sample>) {
97 let _pushed_out_of_ringbuf = self.inner.force_push(samples);
98 }
99}
100
101/// constant source, only works on a single span
102#[derive(Debug)]
103pub struct Replayable<S: Source> {
104 inner: S,
105 buffer: Vec<Sample>,
106 chunk_size: usize,
107 tx: Arc<ReplayQueue>,
108 is_active: Arc<AtomicBool>,
109}
110
111impl<S: Source> Iterator for Replayable<S> {
112 type Item = Sample;
113
114 fn next(&mut self) -> Option<Self::Item> {
115 if let Some(sample) = self.inner.next() {
116 self.buffer.push(sample);
117 // If the buffer is full send it
118 if self.buffer.len() == self.chunk_size {
119 self.tx.push_normal(std::mem::take(&mut self.buffer));
120 }
121 Some(sample)
122 } else {
123 let last_chunk = std::mem::take(&mut self.buffer);
124 self.tx.push_last(last_chunk);
125 self.is_active.store(false, Ordering::Relaxed);
126 None
127 }
128 }
129
130 fn size_hint(&self) -> (usize, Option<usize>) {
131 self.inner.size_hint()
132 }
133}
134
135impl<S: Source> Source for Replayable<S> {
136 fn current_span_len(&self) -> Option<usize> {
137 self.inner.current_span_len()
138 }
139
140 fn channels(&self) -> ChannelCount {
141 self.inner.channels()
142 }
143
144 fn sample_rate(&self) -> SampleRate {
145 self.inner.sample_rate()
146 }
147
148 fn total_duration(&self) -> Option<Duration> {
149 self.inner.total_duration()
150 }
151}
152
153/// constant source, only works on a single span
154#[derive(Debug)]
155pub struct Replay {
156 rx: Arc<ReplayQueue>,
157 buffer: std::vec::IntoIter<Sample>,
158 sleep_duration: Duration,
159 sample_rate: SampleRate,
160 channel_count: ChannelCount,
161 source_is_active: Arc<AtomicBool>,
162}
163
164impl Replay {
165 pub fn source_is_active(&self) -> bool {
166 // - source could return None and not drop
167 // - source could be dropped before returning None
168 self.source_is_active.load(Ordering::Relaxed) && Arc::strong_count(&self.rx) < 2
169 }
170
171 /// Duration of what is in the buffer and can be returned without blocking.
172 pub fn duration_ready(&self) -> Duration {
173 let samples_per_second = self.channels().get() as u32 * self.sample_rate().get();
174
175 let seconds_queued = self.samples_ready() as f64 / samples_per_second as f64;
176 Duration::from_secs_f64(seconds_queued)
177 }
178
179 /// Number of samples in the buffer and can be returned without blocking.
180 pub fn samples_ready(&self) -> usize {
181 self.rx.len() + self.buffer.len()
182 }
183}
184
185impl Iterator for Replay {
186 type Item = Sample;
187
188 fn next(&mut self) -> Option<Self::Item> {
189 if let Some(sample) = self.buffer.next() {
190 return Some(sample);
191 }
192
193 loop {
194 if let Some(new_buffer) = self.rx.pop() {
195 self.buffer = new_buffer.into_iter();
196 return self.buffer.next();
197 }
198
199 if !self.source_is_active() {
200 return None;
201 }
202
203 // The queue does not support blocking on a next item. We want this queue as it
204 // is quite fast and provides a fixed size. We know how many samples are in a
205 // buffer so if we do not get one now we must be getting one after `sleep_duration`.
206 std::thread::sleep(self.sleep_duration);
207 }
208 }
209
210 fn size_hint(&self) -> (usize, Option<usize>) {
211 ((self.rx.len() + self.buffer.len()), None)
212 }
213}
214
215impl Source for Replay {
216 fn current_span_len(&self) -> Option<usize> {
217 None // source is not compatible with spans
218 }
219
220 fn channels(&self) -> ChannelCount {
221 self.channel_count
222 }
223
224 fn sample_rate(&self) -> SampleRate {
225 self.sample_rate
226 }
227
228 fn total_duration(&self) -> Option<Duration> {
229 None
230 }
231}
232
#[cfg(test)]
mod tests {
    use rodio::{nz, static_buffer::StaticSamplesBuffer};

    use super::*;
    use crate::{
        RodioExt,
        rodio_ext::tests::{SAMPLES, test_source},
    };

    /// After replaying part of the history, the replay continues with samples
    /// recorded later rather than starting over.
    #[test]
    fn continues_after_history() {
        let (mut replay, mut recorder) = test_source()
            .replayable(Duration::from_secs(3))
            .expect("longer than 100ms");

        recorder.by_ref().take(3).count();
        let first_part: Vec<Sample> = replay.by_ref().take(3).collect();
        assert_eq!(first_part.as_slice(), &SAMPLES[0..3]);

        recorder.count();
        let remainder: Vec<Sample> = replay.collect();
        assert_eq!(remainder.as_slice(), &SAMPLES[3..5]);
    }

    /// Only the most recent `duration` worth of samples is kept for replay.
    #[test]
    fn keeps_only_latest() {
        let (mut replay, mut recorder) = test_source()
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");

        // Pull every sample without letting the recorder observe the end yet.
        recorder.by_ref().take(5).count();
        let latest: Vec<Sample> = replay.by_ref().take(2).collect();
        assert_eq!(latest.as_slice(), &SAMPLES[3..5]);

        // Let the recorder observe the end of its source.
        recorder.count();
        assert_eq!(replay.next(), None);
    }

    /// About `duration` worth of samples (within one 100ms chunk) is retained.
    #[test]
    fn keeps_correct_amount_of_seconds() {
        let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
        let (replay, mut recorder) = input
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");

        // Pull every sample without letting the recorder observe the end.
        recorder.by_ref().take(40_000).count();

        // Take everything that is available without blocking.
        let available = replay.samples_ready();
        let n_replayed = replay.take_samples(available).count();

        let two_seconds = recorder.sample_rate().get() * recorder.channels().get() as u32 * 2;
        let one_chunk = 16_000 / 10; // 100ms
        assert!(n_replayed as u32 >= two_seconds - one_chunk);
    }

    /// `samples_ready` starts at zero and then tracks what has been recorded.
    #[test]
    fn samples_ready() {
        let input = StaticSamplesBuffer::new(nz!(1), nz!(16_000), &[0.0; 40_000]);
        let (replay, recorder) = input
            .replayable(Duration::from_secs(2))
            .expect("longer than 100ms");
        assert_eq!(replay.samples_ready(), 0);

        recorder.take(8000).count(); // half a second
        let one_chunk = 16_000 / 10; // 100ms
        assert!(replay.samples_ready() >= 8000 - one_chunk);
    }
}