lib.rs

  1mod engine;
  2
  3use core::fmt;
  4use std::{collections::VecDeque, sync::mpsc, thread};
  5
  6pub use engine::Engine;
  7use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
  8
  9use crate::engine::BLOCK_SHIFT;
 10
 11const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
 12const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
 13
 14pub struct Denoiser<S: Source> {
 15    inner: S,
 16    input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
 17    denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
 18    ready: [Sample; BLOCK_SHIFT],
 19    next: usize,
 20    state: IterState,
 21    // When disabled instead of reading denoised sub-blocks from the engine through
 22    // `denoised_rx` we read unprocessed from this queue. This maintains the same
 23    // latency so we can 'trivially' re-enable
 24    queued: Queue,
 25}
 26
 27impl<S: Source> fmt::Debug for Denoiser<S> {
 28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 29        f.debug_struct("Denoiser")
 30            .field("state", &self.state)
 31            .finish_non_exhaustive()
 32    }
 33}
 34
 35struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
 36
 37impl Queue {
 38    fn new() -> Self {
 39        Self(VecDeque::new())
 40    }
 41    fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
 42        self.0.push_back(block);
 43        self.0.resize(4, [0f32; BLOCK_SHIFT]);
 44    }
 45    fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
 46        debug_assert!(self.0.len() == 4);
 47        self.0.pop_front().expect(
 48            "There is no State where the queue is popped while there are less then 4 entries",
 49        )
 50    }
 51}
 52
 53#[derive(Debug, Clone, Copy)]
 54pub enum IterState {
 55    Enabled,
 56    StartingMidAudio { fed_to_denoiser: usize },
 57    Disabled,
 58    Startup { enabled: bool },
 59}
 60
 61#[derive(Debug, thiserror::Error)]
 62pub enum DenoiserError {
 63    #[error("This denoiser only works on sources with samplerate 16000")]
 64    UnsupportedSampleRate,
 65    #[error("This denoiser only works on mono sources (1 channel)")]
 66    UnsupportedChannelCount,
 67}
 68
 69// todo dvdsk needs constant source upstream in rodio
 70impl<S: Source> Denoiser<S> {
 71    pub fn try_new(source: S) -> Result<Self, DenoiserError> {
 72        if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
 73            return Err(DenoiserError::UnsupportedSampleRate);
 74        }
 75        if source.channels() != SUPPORTED_CHANNEL_COUNT {
 76            return Err(DenoiserError::UnsupportedChannelCount);
 77        }
 78
 79        let (input_tx, input_rx) = mpsc::channel();
 80        let (denoised_tx, denoised_rx) = mpsc::channel();
 81
 82        thread::Builder::new()
 83            .name("NeuralDenoiser".to_owned())
 84            .spawn(move || {
 85                run_neural_denoiser(denoised_tx, input_rx);
 86            })
 87            .unwrap();
 88
 89        Ok(Self {
 90            inner: source,
 91            input_tx,
 92            denoised_rx,
 93            ready: [0.0; BLOCK_SHIFT],
 94            state: IterState::Startup { enabled: true },
 95            next: BLOCK_SHIFT,
 96            queued: Queue::new(),
 97        })
 98    }
 99
100    pub fn set_enabled(&mut self, enabled: bool) {
101        self.state = match (enabled, self.state) {
102            (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
103                IterState::Disabled
104            }
105            (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
106            (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
107            (_, state) => state,
108        };
109    }
110
111    fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
112        self.input_tx.send(sub_block).unwrap();
113    }
114}
115
116fn run_neural_denoiser(
117    denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
118    input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
119) {
120    let mut engine = Engine::new();
121    loop {
122        let Ok(sub_block) = input_rx.recv() else {
123            // tx must have dropped, stop thread
124            break;
125        };
126
127        let denoised_sub_block = engine.feed(&sub_block);
128        if denoised_tx.send(denoised_sub_block).is_err() {
129            break;
130        }
131    }
132}
133
134impl<S: Source> Source for Denoiser<S> {
135    fn current_span_len(&self) -> Option<usize> {
136        self.inner.current_span_len()
137    }
138
139    fn channels(&self) -> rodio::ChannelCount {
140        self.inner.channels()
141    }
142
143    fn sample_rate(&self) -> rodio::SampleRate {
144        self.inner.sample_rate()
145    }
146
147    fn total_duration(&self) -> Option<std::time::Duration> {
148        self.inner.total_duration()
149    }
150}
151
152impl<S: Source> Iterator for Denoiser<S> {
153    type Item = Sample;
154
155    #[inline]
156    fn next(&mut self) -> Option<Self::Item> {
157        self.next += 1;
158        if self.next < self.ready.len() {
159            let sample = self.ready[self.next];
160            return Some(sample);
161        }
162
163        // This is a separate function to prevent it from being inlined
164        // as this code only runs once every 128 samples
165        self.prepare_next_ready()
166            .inspect_err(|_| {
167                log::error!("Denoise engine crashed");
168            })
169            .ok()
170            .flatten()
171    }
172}
173
/// Returned when sending a sub-block to the denoise thread, or receiving one
/// from it, fails because the other end of the channel is gone.
#[derive(Debug, thiserror::Error)]
#[error("Could not send or receive from denoise thread. It must have crashed")]
struct DenoiseEngineCrashed;
177
178impl<S: Source> Denoiser<S> {
179    #[cold]
180    fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
181        self.state = match self.state {
182            IterState::Startup { enabled } => {
183                // guaranteed to be coming from silence
184                for _ in 0..3 {
185                    let Some(sub_block) = read_sub_block(&mut self.inner) else {
186                        return Ok(None);
187                    };
188                    self.queued.push(sub_block);
189                    self.input_tx
190                        .send(sub_block)
191                        .map_err(|_| DenoiseEngineCrashed)?;
192                }
193                let Some(sub_block) = read_sub_block(&mut self.inner) else {
194                    return Ok(None);
195                };
196                self.queued.push(sub_block);
197                self.input_tx
198                    .send(sub_block)
199                    .map_err(|_| DenoiseEngineCrashed)?;
200                // throw out old blocks that are denoised silence
201                let _ = self.denoised_rx.iter().take(3).count();
202                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
203
204                let Some(sub_block) = read_sub_block(&mut self.inner) else {
205                    return Ok(None);
206                };
207                self.queued.push(sub_block);
208                self.feed(sub_block);
209
210                if enabled {
211                    IterState::Enabled
212                } else {
213                    IterState::Disabled
214                }
215            }
216            IterState::Enabled => {
217                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
218                let Some(sub_block) = read_sub_block(&mut self.inner) else {
219                    return Ok(None);
220                };
221                self.queued.push(sub_block);
222                self.input_tx
223                    .send(sub_block)
224                    .map_err(|_| DenoiseEngineCrashed)?;
225                IterState::Enabled
226            }
227            IterState::Disabled => {
228                // Need to maintain the same 512 samples delay such that
229                // we can re-enable at any point.
230                self.ready = self.queued.pop();
231                let Some(sub_block) = read_sub_block(&mut self.inner) else {
232                    return Ok(None);
233                };
234                self.queued.push(sub_block);
235                IterState::Disabled
236            }
237            IterState::StartingMidAudio {
238                fed_to_denoiser: mut sub_blocks_fed,
239            } => {
240                self.ready = self.queued.pop();
241                let Some(sub_block) = read_sub_block(&mut self.inner) else {
242                    return Ok(None);
243                };
244                self.queued.push(sub_block);
245                self.input_tx
246                    .send(sub_block)
247                    .map_err(|_| DenoiseEngineCrashed)?;
248                sub_blocks_fed += 1;
249                if sub_blocks_fed > 4 {
250                    // throw out partially denoised blocks,
251                    // next will be correctly denoised
252                    let _ = self.denoised_rx.iter().take(3).count();
253                    IterState::Enabled
254                } else {
255                    IterState::StartingMidAudio {
256                        fed_to_denoiser: sub_blocks_fed,
257                    }
258                }
259            }
260        };
261
262        self.next = 0;
263        Ok(Some(self.ready[0]))
264    }
265}
266
267fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
268    let mut res = [0f32; BLOCK_SHIFT];
269    for sample in &mut res {
270        *sample = s.next()?;
271    }
272    Some(res)
273}