1mod engine;
2
3use core::fmt;
4use std::{collections::VecDeque, sync::mpsc, thread};
5
6pub use engine::Engine;
7use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
8
9use crate::engine::BLOCK_SHIFT;
10
/// The only sample rate `Denoiser::try_new` accepts from its source (16 kHz).
const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
/// The only channel count `Denoiser::try_new` accepts from its source (mono).
const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
13
/// A [`Source`] adapter that denoises a 16 kHz mono source.
///
/// Samples are read from the wrapped source in sub-blocks of `BLOCK_SHIFT` samples. Each
/// sub-block is sent to a background thread running the denoise [`Engine`], and the
/// denoised sub-blocks are played back. The output lags the input by the same number of
/// sub-blocks whether denoising is on or off, so it can be toggled mid-stream with
/// [`Denoiser::set_enabled`].
// Field order is also drop order: the inner source is dropped first, then the channel
// ends, whose closing lets the engine thread exit.
pub struct Denoiser<S: Source> {
    /// The wrapped source samples are read from.
    inner: S,
    /// Sends raw sub-blocks to the engine thread.
    input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
    /// Receives denoised sub-blocks from the engine thread, in the order they were sent.
    denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
    /// The sub-block currently being played.
    ready: [Sample; BLOCK_SHIFT],
    /// Index into `ready` of the sample returned most recently; a value of
    /// `BLOCK_SHIFT` or more means `ready` has been used up.
    next: usize,
    /// Startup / enabled / disabled state machine, see [`IterState`].
    state: IterState,
    // While disabled, sub-blocks are read unprocessed from this queue instead of as
    // denoised sub-blocks from the engine through `denoised_rx`. The queue delays the
    // audio by as much as the engine does, so denoising can be re-enabled without a jump.
    queued: Queue,
}
26
27impl<S: Source> fmt::Debug for Denoiser<S> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("Denoiser")
30 .field("state", &self.state)
31 .finish_non_exhaustive()
32 }
33}
34
35struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
36
37impl Queue {
38 fn new() -> Self {
39 Self(VecDeque::new())
40 }
41 fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
42 self.0.push_back(block);
43 self.0.resize(4, [0f32; BLOCK_SHIFT]);
44 }
45 fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
46 debug_assert!(self.0.len() == 4);
47 self.0.pop_front().expect(
48 "There is no State where the queue is popped while there are less then 4 entries",
49 )
50 }
51}
52
/// Where a [`Denoiser`] is in its startup / enable / disable cycle.
///
/// The state is advanced once per sub-block while samples are pulled, and changed by
/// [`Denoiser::set_enabled`].
#[derive(Debug, Clone, Copy)]
pub enum IterState {
    /// Output sub-blocks come from the denoise engine.
    Enabled,
    /// Switching back on mid-stream: raw audio is still played from the delay queue
    /// while fresh sub-blocks are fed to the engine.
    StartingMidAudio {
        /// Sub-blocks fed to the engine since re-enabling started.
        fed_to_denoiser: usize,
    },
    /// Output is the raw input, delayed by the same amount as the denoised output.
    Disabled,
    /// No sub-block played yet; the engine is primed when the first sample is requested.
    Startup {
        /// Whether to continue as `Enabled` (`true`) or `Disabled` (`false`) once primed.
        enabled: bool,
    },
}
60
/// Why [`Denoiser::try_new`] rejected a source.
#[derive(Debug, thiserror::Error)]
pub enum DenoiserError {
    /// The source does not report a sample rate of 16 000 Hz.
    #[error("This denoiser only works on sources with samplerate 16000")]
    UnsupportedSampleRate,
    /// The source does not report exactly one channel.
    #[error("This denoiser only works on mono sources (1 channel)")]
    UnsupportedChannelCount,
}
68
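// TODO(dvdsk): needs a constant-format source type upstream in rodio. Until then,
// `try_new` can only check the format the source reports when it is wrapped.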
70impl<S: Source> Denoiser<S> {
71 pub fn try_new(source: S) -> Result<Self, DenoiserError> {
72 if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
73 return Err(DenoiserError::UnsupportedSampleRate);
74 }
75 if source.channels() != SUPPORTED_CHANNEL_COUNT {
76 return Err(DenoiserError::UnsupportedChannelCount);
77 }
78
79 let (input_tx, input_rx) = mpsc::channel();
80 let (denoised_tx, denoised_rx) = mpsc::channel();
81
82 thread::Builder::new()
83 .name("NeuralDenoiser".to_owned())
84 .spawn(move || {
85 run_neural_denoiser(denoised_tx, input_rx);
86 })
87 .expect("Should be ablet to spawn threads");
88
89 Ok(Self {
90 inner: source,
91 input_tx,
92 denoised_rx,
93 ready: [0.0; BLOCK_SHIFT],
94 state: IterState::Startup { enabled: true },
95 next: BLOCK_SHIFT,
96 queued: Queue::new(),
97 })
98 }
99
100 pub fn set_enabled(&mut self, enabled: bool) {
101 self.state = match (enabled, self.state) {
102 (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
103 IterState::Disabled
104 }
105 (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
106 (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
107 (_, state) => state,
108 };
109 }
110
111 fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
112 self.input_tx.send(sub_block).unwrap();
113 }
114}
115
116fn run_neural_denoiser(
117 denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
118 input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
119) {
120 let mut engine = Engine::new();
121 loop {
122 let Ok(sub_block) = input_rx.recv() else {
123 // tx must have dropped, stop thread
124 break;
125 };
126
127 let denoised_sub_block = engine.feed(&sub_block);
128 if denoised_tx.send(denoised_sub_block).is_err() {
129 break;
130 }
131 }
132}
133
134impl<S: Source> Source for Denoiser<S> {
135 fn current_span_len(&self) -> Option<usize> {
136 self.inner.current_span_len()
137 }
138
139 fn channels(&self) -> rodio::ChannelCount {
140 self.inner.channels()
141 }
142
143 fn sample_rate(&self) -> rodio::SampleRate {
144 self.inner.sample_rate()
145 }
146
147 fn total_duration(&self) -> Option<std::time::Duration> {
148 self.inner.total_duration()
149 }
150}
151
152impl<S: Source> Iterator for Denoiser<S> {
153 type Item = Sample;
154
155 #[inline]
156 fn next(&mut self) -> Option<Self::Item> {
157 self.next += 1;
158 if self.next < self.ready.len() {
159 let sample = self.ready[self.next];
160 return Some(sample);
161 }
162
163 // This is a separate function to prevent it from being inlined
164 // as this code only runs once every 128 samples
165 self.prepare_next_ready()
166 .inspect_err(|_| {
167 log::error!("Denoise engine crashed");
168 })
169 .ok()
170 .flatten()
171 }
172}
173
/// A channel to or from the denoise engine thread is closed.
///
/// The `Denoiser` keeps its own channel ends open while it exists, so this means the
/// engine thread has exited (presumably by panicking).
#[derive(Debug, thiserror::Error)]
#[error("Could not send or receive from denoise thread. It must have crashed")]
struct DenoiseEngineCrashed;
177
178impl<S: Source> Denoiser<S> {
179 #[cold]
180 fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
181 self.state = match self.state {
182 IterState::Startup { enabled } => {
183 // guaranteed to be coming from silence
184 for _ in 0..3 {
185 let Some(sub_block) = read_sub_block(&mut self.inner) else {
186 return Ok(None);
187 };
188 self.queued.push(sub_block);
189 self.input_tx
190 .send(sub_block)
191 .map_err(|_| DenoiseEngineCrashed)?;
192 }
193 let Some(sub_block) = read_sub_block(&mut self.inner) else {
194 return Ok(None);
195 };
196 self.queued.push(sub_block);
197 self.input_tx
198 .send(sub_block)
199 .map_err(|_| DenoiseEngineCrashed)?;
200 // throw out old blocks that are denoised silence
201 let _ = self.denoised_rx.iter().take(3).count();
202 self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
203
204 let Some(sub_block) = read_sub_block(&mut self.inner) else {
205 return Ok(None);
206 };
207 self.queued.push(sub_block);
208 self.feed(sub_block);
209
210 if enabled {
211 IterState::Enabled
212 } else {
213 IterState::Disabled
214 }
215 }
216 IterState::Enabled => {
217 self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
218 let Some(sub_block) = read_sub_block(&mut self.inner) else {
219 return Ok(None);
220 };
221 self.queued.push(sub_block);
222 self.input_tx
223 .send(sub_block)
224 .map_err(|_| DenoiseEngineCrashed)?;
225 IterState::Enabled
226 }
227 IterState::Disabled => {
228 // Need to maintain the same 512 samples delay such that
229 // we can re-enable at any point.
230 self.ready = self.queued.pop();
231 let Some(sub_block) = read_sub_block(&mut self.inner) else {
232 return Ok(None);
233 };
234 self.queued.push(sub_block);
235 IterState::Disabled
236 }
237 IterState::StartingMidAudio {
238 fed_to_denoiser: mut sub_blocks_fed,
239 } => {
240 self.ready = self.queued.pop();
241 let Some(sub_block) = read_sub_block(&mut self.inner) else {
242 return Ok(None);
243 };
244 self.queued.push(sub_block);
245 self.input_tx
246 .send(sub_block)
247 .map_err(|_| DenoiseEngineCrashed)?;
248 sub_blocks_fed += 1;
249 if sub_blocks_fed > 4 {
250 // throw out partially denoised blocks,
251 // next will be correctly denoised
252 let _ = self.denoised_rx.iter().take(3).count();
253 IterState::Enabled
254 } else {
255 IterState::StartingMidAudio {
256 fed_to_denoiser: sub_blocks_fed,
257 }
258 }
259 }
260 };
261
262 self.next = 0;
263 Ok(Some(self.ready[0]))
264 }
265}
266
267fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
268 let mut res = [0f32; BLOCK_SHIFT];
269 for sample in &mut res {
270 *sample = s.next()?;
271 }
272 Some(res)
273}