1mod engine;
2
3use core::fmt;
4use std::{collections::VecDeque, sync::mpsc, thread};
5
6pub use engine::Engine;
7use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
8
9use crate::engine::BLOCK_SHIFT;
10
11const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
12const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
13
14pub struct Denoiser<S: Source> {
15 inner: S,
16 input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
17 denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
18 ready: [Sample; BLOCK_SHIFT],
19 next: usize,
20 state: IterState,
21 // When disabled instead of reading denoised sub-blocks from the engine through
22 // `denoised_rx` we read unprocessed from this queue. This maintains the same
23 // latency so we can 'trivially' re-enable
24 queued: Queue,
25}
26
27impl<S: Source> fmt::Debug for Denoiser<S> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("Denoiser")
30 .field("state", &self.state)
31 .finish_non_exhaustive()
32 }
33}
34
35struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
36
37impl Queue {
38 fn new() -> Self {
39 Self(VecDeque::new())
40 }
41 fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
42 self.0.push_back(block);
43 self.0.resize(4, [0f32; BLOCK_SHIFT]);
44 }
45 fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
46 debug_assert!(self.0.len() == 4);
47 self.0.pop_front().expect(
48 "There is no State where the queue is popped while there are less then 4 entries",
49 )
50 }
51}
52
/// Mode the [`Denoiser`] iterator is in. It decides where the next sub-block
/// of output comes from.
#[derive(Clone, Copy, Debug)]
pub enum IterState {
    /// Output is taken from the denoise thread.
    Enabled,
    /// Denoising was switched back on after being disabled. Output still
    /// comes from the queue of unprocessed sub-blocks while the denoise
    /// thread is being fed again.
    StartingMidAudio {
        /// Number of sub-blocks sent to the denoise thread since re-enabling.
        fed_to_denoiser: usize,
    },
    /// Output is the unprocessed audio, taken from the queue.
    Disabled,
    /// Initial state, before the first sub-block has been prepared.
    Startup {
        /// Whether to continue as `Enabled` (`true`) or `Disabled` (`false`)
        /// once the first sub-block is ready.
        enabled: bool,
    },
}
60
61#[derive(Debug, thiserror::Error)]
62pub enum DenoiserError {
63 #[error("This denoiser only works on sources with samplerate 16000")]
64 UnsupportedSampleRate,
65 #[error("This denoiser only works on mono sources (1 channel)")]
66 UnsupportedChannelCount,
67}
68
69// todo dvdsk needs constant source upstream in rodio
70impl<S: Source> Denoiser<S> {
71 pub fn try_new(source: S) -> Result<Self, DenoiserError> {
72 if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
73 return Err(DenoiserError::UnsupportedSampleRate);
74 }
75 if source.channels() != SUPPORTED_CHANNEL_COUNT {
76 return Err(DenoiserError::UnsupportedChannelCount);
77 }
78
79 let (input_tx, input_rx) = mpsc::channel();
80 let (denoised_tx, denoised_rx) = mpsc::channel();
81
82 thread::spawn(move || {
83 run_neural_denoiser(denoised_tx, input_rx);
84 });
85
86 Ok(Self {
87 inner: source,
88 input_tx,
89 denoised_rx,
90 ready: [0.0; BLOCK_SHIFT],
91 state: IterState::Startup { enabled: true },
92 next: BLOCK_SHIFT,
93 queued: Queue::new(),
94 })
95 }
96
97 pub fn set_enabled(&mut self, enabled: bool) {
98 self.state = match (enabled, self.state) {
99 (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
100 IterState::Disabled
101 }
102 (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
103 (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
104 (_, state) => state,
105 };
106 }
107
108 fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
109 self.input_tx.send(sub_block).unwrap();
110 }
111}
112
113fn run_neural_denoiser(
114 denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
115 input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
116) {
117 let mut engine = Engine::new();
118 loop {
119 let Ok(sub_block) = input_rx.recv() else {
120 // tx must have dropped, stop thread
121 break;
122 };
123
124 let denoised_sub_block = engine.feed(&sub_block);
125 if denoised_tx.send(denoised_sub_block).is_err() {
126 break;
127 }
128 }
129}
130
131impl<S: Source> Source for Denoiser<S> {
132 fn current_span_len(&self) -> Option<usize> {
133 self.inner.current_span_len()
134 }
135
136 fn channels(&self) -> rodio::ChannelCount {
137 self.inner.channels()
138 }
139
140 fn sample_rate(&self) -> rodio::SampleRate {
141 self.inner.sample_rate()
142 }
143
144 fn total_duration(&self) -> Option<std::time::Duration> {
145 self.inner.total_duration()
146 }
147}
148
149impl<S: Source> Iterator for Denoiser<S> {
150 type Item = Sample;
151
152 #[inline]
153 fn next(&mut self) -> Option<Self::Item> {
154 self.next += 1;
155 if self.next < self.ready.len() {
156 let sample = self.ready[self.next];
157 return Some(sample);
158 }
159
160 // This is a separate function to prevent it from being inlined
161 // as this code only runs once every 128 samples
162 self.prepare_next_ready()
163 .inspect_err(|_| {
164 log::error!("Denoise engine crashed");
165 })
166 .ok()
167 .flatten()
168 }
169}
170
171#[derive(Debug, thiserror::Error)]
172#[error("Could not send or receive from denoise thread. It must have crashed")]
173struct DenoiseEngineCrashed;
174
175impl<S: Source> Denoiser<S> {
176 #[cold]
177 fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
178 self.state = match self.state {
179 IterState::Startup { enabled } => {
180 // guaranteed to be coming from silence
181 for _ in 0..3 {
182 let Some(sub_block) = read_sub_block(&mut self.inner) else {
183 return Ok(None);
184 };
185 self.queued.push(sub_block);
186 self.input_tx
187 .send(sub_block)
188 .map_err(|_| DenoiseEngineCrashed)?;
189 }
190 let Some(sub_block) = read_sub_block(&mut self.inner) else {
191 return Ok(None);
192 };
193 self.queued.push(sub_block);
194 self.input_tx
195 .send(sub_block)
196 .map_err(|_| DenoiseEngineCrashed)?;
197 // throw out old blocks that are denoised silence
198 let _ = self.denoised_rx.iter().take(3).count();
199 self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
200
201 let Some(sub_block) = read_sub_block(&mut self.inner) else {
202 return Ok(None);
203 };
204 self.queued.push(sub_block);
205 self.feed(sub_block);
206
207 if enabled {
208 IterState::Enabled
209 } else {
210 IterState::Disabled
211 }
212 }
213 IterState::Enabled => {
214 self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
215 let Some(sub_block) = read_sub_block(&mut self.inner) else {
216 return Ok(None);
217 };
218 self.queued.push(sub_block);
219 self.input_tx
220 .send(sub_block)
221 .map_err(|_| DenoiseEngineCrashed)?;
222 IterState::Enabled
223 }
224 IterState::Disabled => {
225 // Need to maintain the same 512 samples delay such that
226 // we can re-enable at any point.
227 self.ready = self.queued.pop();
228 let Some(sub_block) = read_sub_block(&mut self.inner) else {
229 return Ok(None);
230 };
231 self.queued.push(sub_block);
232 IterState::Disabled
233 }
234 IterState::StartingMidAudio {
235 fed_to_denoiser: mut sub_blocks_fed,
236 } => {
237 self.ready = self.queued.pop();
238 let Some(sub_block) = read_sub_block(&mut self.inner) else {
239 return Ok(None);
240 };
241 self.queued.push(sub_block);
242 self.input_tx
243 .send(sub_block)
244 .map_err(|_| DenoiseEngineCrashed)?;
245 sub_blocks_fed += 1;
246 if sub_blocks_fed > 4 {
247 // throw out partially denoised blocks,
248 // next will be correctly denoised
249 let _ = self.denoised_rx.iter().take(3).count();
250 IterState::Enabled
251 } else {
252 IterState::StartingMidAudio {
253 fed_to_denoiser: sub_blocks_fed,
254 }
255 }
256 }
257 };
258
259 self.next = 0;
260 Ok(Some(self.ready[0]))
261 }
262}
263
264fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
265 let mut res = [0f32; BLOCK_SHIFT];
266 for sample in &mut res {
267 *sample = s.next()?;
268 }
269 Some(res)
270}