lib.rs

  1mod engine;
  2
  3use core::fmt;
  4use std::{collections::VecDeque, sync::mpsc, thread};
  5
  6pub use engine::Engine;
  7use rodio::{ChannelCount, Sample, SampleRate, Source, nz};
  8
  9use crate::engine::BLOCK_SHIFT;
 10
 11const SUPPORTED_SAMPLE_RATE: SampleRate = nz!(16_000);
 12const SUPPORTED_CHANNEL_COUNT: ChannelCount = nz!(1);
 13
 14pub struct Denoiser<S: Source> {
 15    inner: S,
 16    input_tx: mpsc::Sender<[Sample; BLOCK_SHIFT]>,
 17    denoised_rx: mpsc::Receiver<[Sample; BLOCK_SHIFT]>,
 18    ready: [Sample; BLOCK_SHIFT],
 19    next: usize,
 20    state: IterState,
 21    // When disabled instead of reading denoised sub-blocks from the engine through
 22    // `denoised_rx` we read unprocessed from this queue. This maintains the same
 23    // latency so we can 'trivially' re-enable
 24    queued: Queue,
 25}
 26
 27impl<S: Source> fmt::Debug for Denoiser<S> {
 28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 29        f.debug_struct("Denoiser")
 30            .field("state", &self.state)
 31            .finish_non_exhaustive()
 32    }
 33}
 34
 35struct Queue(VecDeque<[Sample; BLOCK_SHIFT]>);
 36
 37impl Queue {
 38    fn new() -> Self {
 39        Self(VecDeque::new())
 40    }
 41    fn push(&mut self, block: [Sample; BLOCK_SHIFT]) {
 42        self.0.push_back(block);
 43        self.0.resize(4, [0f32; BLOCK_SHIFT]);
 44    }
 45    fn pop(&mut self) -> [Sample; BLOCK_SHIFT] {
 46        debug_assert!(self.0.len() == 4);
 47        self.0.pop_front().expect(
 48            "There is no State where the queue is popped while there are less then 4 entries",
 49        )
 50    }
 51}
 52
 53#[derive(Debug, Clone, Copy)]
 54pub enum IterState {
 55    Enabled,
 56    StartingMidAudio { fed_to_denoiser: usize },
 57    Disabled,
 58    Startup { enabled: bool },
 59}
 60
 61#[derive(Debug, thiserror::Error)]
 62pub enum DenoiserError {
 63    #[error("This denoiser only works on sources with samplerate 16000")]
 64    UnsupportedSampleRate,
 65    #[error("This denoiser only works on mono sources (1 channel)")]
 66    UnsupportedChannelCount,
 67}
 68
 69// todo dvdsk needs constant source upstream in rodio
 70impl<S: Source> Denoiser<S> {
 71    pub fn try_new(source: S) -> Result<Self, DenoiserError> {
 72        if source.sample_rate() != SUPPORTED_SAMPLE_RATE {
 73            return Err(DenoiserError::UnsupportedSampleRate);
 74        }
 75        if source.channels() != SUPPORTED_CHANNEL_COUNT {
 76            return Err(DenoiserError::UnsupportedChannelCount);
 77        }
 78
 79        let (input_tx, input_rx) = mpsc::channel();
 80        let (denoised_tx, denoised_rx) = mpsc::channel();
 81
 82        thread::spawn(move || {
 83            run_neural_denoiser(denoised_tx, input_rx);
 84        });
 85
 86        Ok(Self {
 87            inner: source,
 88            input_tx,
 89            denoised_rx,
 90            ready: [0.0; BLOCK_SHIFT],
 91            state: IterState::Startup { enabled: true },
 92            next: BLOCK_SHIFT,
 93            queued: Queue::new(),
 94        })
 95    }
 96
 97    pub fn set_enabled(&mut self, enabled: bool) {
 98        self.state = match (enabled, self.state) {
 99            (false, IterState::StartingMidAudio { .. }) | (false, IterState::Enabled) => {
100                IterState::Disabled
101            }
102            (false, IterState::Startup { enabled: true }) => IterState::Startup { enabled: false },
103            (true, IterState::Disabled) => IterState::StartingMidAudio { fed_to_denoiser: 0 },
104            (_, state) => state,
105        };
106    }
107
108    fn feed(&self, sub_block: [f32; BLOCK_SHIFT]) {
109        self.input_tx.send(sub_block).unwrap();
110    }
111}
112
113fn run_neural_denoiser(
114    denoised_tx: mpsc::Sender<[f32; BLOCK_SHIFT]>,
115    input_rx: mpsc::Receiver<[f32; BLOCK_SHIFT]>,
116) {
117    let mut engine = Engine::new();
118    loop {
119        let Ok(sub_block) = input_rx.recv() else {
120            // tx must have dropped, stop thread
121            break;
122        };
123
124        let denoised_sub_block = engine.feed(&sub_block);
125        if denoised_tx.send(denoised_sub_block).is_err() {
126            break;
127        }
128    }
129}
130
131impl<S: Source> Source for Denoiser<S> {
132    fn current_span_len(&self) -> Option<usize> {
133        self.inner.current_span_len()
134    }
135
136    fn channels(&self) -> rodio::ChannelCount {
137        self.inner.channels()
138    }
139
140    fn sample_rate(&self) -> rodio::SampleRate {
141        self.inner.sample_rate()
142    }
143
144    fn total_duration(&self) -> Option<std::time::Duration> {
145        self.inner.total_duration()
146    }
147}
148
149impl<S: Source> Iterator for Denoiser<S> {
150    type Item = Sample;
151
152    #[inline]
153    fn next(&mut self) -> Option<Self::Item> {
154        self.next += 1;
155        if self.next < self.ready.len() {
156            let sample = self.ready[self.next];
157            return Some(sample);
158        }
159
160        // This is a separate function to prevent it from being inlined
161        // as this code only runs once every 128 samples
162        self.prepare_next_ready()
163            .inspect_err(|_| {
164                log::error!("Denoise engine crashed");
165            })
166            .ok()
167            .flatten()
168    }
169}
170
171#[derive(Debug, thiserror::Error)]
172#[error("Could not send or receive from denoise thread. It must have crashed")]
173struct DenoiseEngineCrashed;
174
175impl<S: Source> Denoiser<S> {
176    #[cold]
177    fn prepare_next_ready(&mut self) -> Result<Option<f32>, DenoiseEngineCrashed> {
178        self.state = match self.state {
179            IterState::Startup { enabled } => {
180                // guaranteed to be coming from silence
181                for _ in 0..3 {
182                    let Some(sub_block) = read_sub_block(&mut self.inner) else {
183                        return Ok(None);
184                    };
185                    self.queued.push(sub_block);
186                    self.input_tx
187                        .send(sub_block)
188                        .map_err(|_| DenoiseEngineCrashed)?;
189                }
190                let Some(sub_block) = read_sub_block(&mut self.inner) else {
191                    return Ok(None);
192                };
193                self.queued.push(sub_block);
194                self.input_tx
195                    .send(sub_block)
196                    .map_err(|_| DenoiseEngineCrashed)?;
197                // throw out old blocks that are denoised silence
198                let _ = self.denoised_rx.iter().take(3).count();
199                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
200
201                let Some(sub_block) = read_sub_block(&mut self.inner) else {
202                    return Ok(None);
203                };
204                self.queued.push(sub_block);
205                self.feed(sub_block);
206
207                if enabled {
208                    IterState::Enabled
209                } else {
210                    IterState::Disabled
211                }
212            }
213            IterState::Enabled => {
214                self.ready = self.denoised_rx.recv().map_err(|_| DenoiseEngineCrashed)?;
215                let Some(sub_block) = read_sub_block(&mut self.inner) else {
216                    return Ok(None);
217                };
218                self.queued.push(sub_block);
219                self.input_tx
220                    .send(sub_block)
221                    .map_err(|_| DenoiseEngineCrashed)?;
222                IterState::Enabled
223            }
224            IterState::Disabled => {
225                // Need to maintain the same 512 samples delay such that
226                // we can re-enable at any point.
227                self.ready = self.queued.pop();
228                let Some(sub_block) = read_sub_block(&mut self.inner) else {
229                    return Ok(None);
230                };
231                self.queued.push(sub_block);
232                IterState::Disabled
233            }
234            IterState::StartingMidAudio {
235                fed_to_denoiser: mut sub_blocks_fed,
236            } => {
237                self.ready = self.queued.pop();
238                let Some(sub_block) = read_sub_block(&mut self.inner) else {
239                    return Ok(None);
240                };
241                self.queued.push(sub_block);
242                self.input_tx
243                    .send(sub_block)
244                    .map_err(|_| DenoiseEngineCrashed)?;
245                sub_blocks_fed += 1;
246                if sub_blocks_fed > 4 {
247                    // throw out partially denoised blocks,
248                    // next will be correctly denoised
249                    let _ = self.denoised_rx.iter().take(3).count();
250                    IterState::Enabled
251                } else {
252                    IterState::StartingMidAudio {
253                        fed_to_denoiser: sub_blocks_fed,
254                    }
255                }
256            }
257        };
258
259        self.next = 0;
260        Ok(Some(self.ready[0]))
261    }
262}
263
264fn read_sub_block(s: &mut impl Source) -> Option<[f32; BLOCK_SHIFT]> {
265    let mut res = [0f32; BLOCK_SHIFT];
266    for sample in &mut res {
267        *sample = s.next()?;
268    }
269    Some(res)
270}