engine.rs

//! Use a model viewer such as https://netron.app/ to inspect the ONNX models and
//! understand the data flow between them.
  3use std::collections::HashMap;
  4
  5use candle_core::{Device, IndexOp, Tensor};
  6use candle_onnx::onnx::ModelProto;
  7use candle_onnx::prost::Message;
  8use realfft::RealFftPlanner;
  9use rustfft::num_complex::Complex;
 10
 11pub struct Engine {
 12    spectral_model: ModelProto,
 13    signal_model: ModelProto,
 14
 15    fft_planner: RealFftPlanner<f32>,
 16    fft_scratch: Vec<Complex<f32>>,
 17    spectrum: [Complex<f32>; FFT_OUT_SIZE],
 18    signal: [f32; BLOCK_LEN],
 19
 20    in_magnitude: [f32; FFT_OUT_SIZE],
 21    in_phase: [f32; FFT_OUT_SIZE],
 22
 23    spectral_memory: Tensor,
 24    signal_memory: Tensor,
 25
 26    in_buffer: [f32; BLOCK_LEN],
 27    out_buffer: [f32; BLOCK_LEN],
 28}
 29
 30// 32 ms @ 16khz per DTLN docs: https://github.com/breizhn/DTLN
 31pub const BLOCK_LEN: usize = 512;
 32// 8 ms @ 16khz per DTLN docs.
 33pub const BLOCK_SHIFT: usize = 128;
 34pub const FFT_OUT_SIZE: usize = BLOCK_LEN / 2 + 1;
 35
 36impl Engine {
 37    pub fn new() -> Self {
 38        let mut fft_planner = RealFftPlanner::new();
 39        let fft_planned = fft_planner.plan_fft_forward(BLOCK_LEN);
 40        let scratch_len = fft_planned.get_scratch_len();
 41        Self {
 42            // Models are 1.5MB and 2.5MB respectively. Its worth the binary
 43            // size increase not to have to distribute the models separately.
 44            spectral_model: ModelProto::decode(
 45                include_bytes!("../models/model_1_converted_simplified.onnx").as_slice(),
 46            )
 47            .expect("The model should decode"),
 48            signal_model: ModelProto::decode(
 49                include_bytes!("../models/model_2_converted_simplified.onnx").as_slice(),
 50            )
 51            .expect("The model should decode"),
 52            fft_planner,
 53            fft_scratch: vec![Complex::ZERO; scratch_len],
 54            spectrum: [Complex::ZERO; FFT_OUT_SIZE],
 55            signal: [0f32; BLOCK_LEN],
 56
 57            in_magnitude: [0f32; FFT_OUT_SIZE],
 58            in_phase: [0f32; FFT_OUT_SIZE],
 59
 60            spectral_memory: Tensor::from_slice::<_, f32>(
 61                &[0f32; 512],
 62                (1, 2, BLOCK_SHIFT, 2),
 63                &Device::Cpu,
 64            )
 65            .expect("Tensor has the correct dimensions"),
 66            signal_memory: Tensor::from_slice::<_, f32>(
 67                &[0f32; 512],
 68                (1, 2, BLOCK_SHIFT, 2),
 69                &Device::Cpu,
 70            )
 71            .expect("Tensor has the correct dimensions"),
 72            out_buffer: [0f32; BLOCK_LEN],
 73            in_buffer: [0f32; BLOCK_LEN],
 74        }
 75    }
 76
 77    /// Add a clunk of samples and get the denoised chunk 4 feeds later
 78    pub fn feed(&mut self, samples: &[f32]) -> [f32; BLOCK_SHIFT] {
 79        /// The name of the output node of the onnx network
 80        /// [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
 81        const MEMORY_OUTPUT: &'static str = "Identity_1";
 82
 83        debug_assert_eq!(samples.len(), BLOCK_SHIFT);
 84
 85        // place new samples at the end of the `in_buffer`
 86        self.in_buffer.copy_within(BLOCK_SHIFT.., 0);
 87        self.in_buffer[(BLOCK_LEN - BLOCK_SHIFT)..].copy_from_slice(&samples);
 88
 89        // run inference
 90        let inputs = self.spectral_inputs();
 91        let mut spectral_outputs = candle_onnx::simple_eval(&self.spectral_model, inputs)
 92            .expect("The embedded file must be valid");
 93        self.spectral_memory = spectral_outputs
 94            .remove(MEMORY_OUTPUT)
 95            .expect("The model has an output named Identity_1");
 96        let inputs = self.signal_inputs(spectral_outputs);
 97        let mut signal_outputs = candle_onnx::simple_eval(&self.signal_model, inputs)
 98            .expect("The embedded file must be valid");
 99        self.signal_memory = signal_outputs
100            .remove(MEMORY_OUTPUT)
101            .expect("The model has an output named Identity_1");
102        let model_output = model_outputs(signal_outputs);
103
104        // place processed samples at the start of the `out_buffer`
105        // shift the rest left, fill the end with zeros. Zeros are needed as
106        // the out buffer is part of the input of the network
107        self.out_buffer.copy_within(BLOCK_SHIFT.., 0);
108        self.out_buffer[BLOCK_LEN - BLOCK_SHIFT..].fill(0f32);
109        for (a, b) in self.out_buffer.iter_mut().zip(model_output) {
110            *a += b;
111        }
112
113        // samples at the front of the `out_buffer` are now denoised
114        self.out_buffer[..BLOCK_SHIFT]
115            .try_into()
116            .expect("len is correct")
117    }
118
119    fn spectral_inputs(&mut self) -> HashMap<String, Tensor> {
120        // Prepare FFT input
121        let fft = self.fft_planner.plan_fft_forward(BLOCK_LEN);
122
123        // Perform real-to-complex FFT
124        let mut fft_in = self.in_buffer;
125        fft.process_with_scratch(&mut fft_in, &mut self.spectrum, &mut self.fft_scratch)
126            .expect("The fft should run, there is enough scratch space");
127
128        // Generate magnitude and phase
129        for ((magnitude, phase), complex) in self
130            .in_magnitude
131            .iter_mut()
132            .zip(self.in_phase.iter_mut())
133            .zip(self.spectrum)
134        {
135            *magnitude = complex.norm();
136            *phase = complex.arg();
137        }
138
139        const SPECTRUM_INPUT: &str = "input_2";
140        const MEMORY_INPUT: &str = "input_3";
141        let spectrum =
142            Tensor::from_slice::<_, f32>(&self.in_magnitude, (1, 1, FFT_OUT_SIZE), &Device::Cpu)
143                .expect("the in magnitude has enough elements to fill the Tensor");
144
145        let inputs = HashMap::from([
146            (SPECTRUM_INPUT.to_string(), spectrum),
147            (MEMORY_INPUT.to_string(), self.spectral_memory.clone()),
148        ]);
149        inputs
150    }
151
152    fn signal_inputs(&mut self, outputs: HashMap<String, Tensor>) -> HashMap<String, Tensor> {
153        let magnitude_weight = model_outputs(outputs);
154
155        // Apply mask and reconstruct complex spectrum
156        let mut spectrum = [Complex::I; FFT_OUT_SIZE];
157        for i in 0..FFT_OUT_SIZE {
158            let magnitude = self.in_magnitude[i] * magnitude_weight[i];
159            let phase = self.in_phase[i];
160            let real = magnitude * phase.cos();
161            let imag = magnitude * phase.sin();
162            spectrum[i] = Complex::new(real, imag);
163        }
164
165        // Handle DC component (i = 0)
166        let magnitude = self.in_magnitude[0] * magnitude_weight[0];
167        spectrum[0] = Complex::new(magnitude, 0.0);
168
169        // Handle Nyquist component (i = N/2)
170        let magnitude = self.in_magnitude[FFT_OUT_SIZE - 1] * magnitude_weight[FFT_OUT_SIZE - 1];
171        spectrum[FFT_OUT_SIZE - 1] = Complex::new(magnitude, 0.0);
172
173        // Perform complex-to-real IFFT
174        let ifft = self.fft_planner.plan_fft_inverse(BLOCK_LEN);
175        ifft.process_with_scratch(&mut spectrum, &mut self.signal, &mut self.fft_scratch)
176            .expect("The fft should run, there is enough scratch space");
177
178        // Normalize the IFFT output
179        for real in &mut self.signal {
180            *real /= BLOCK_LEN as f32;
181        }
182
183        const SIGNAL_INPUT: &str = "input_4";
184        const SIGNAL_MEMORY: &str = "input_5";
185        let signal_input =
186            Tensor::from_slice::<_, f32>(&self.signal, (1, 1, BLOCK_LEN), &Device::Cpu).unwrap();
187
188        HashMap::from([
189            (SIGNAL_INPUT.to_string(), signal_input),
190            (SIGNAL_MEMORY.to_string(), self.signal_memory.clone()),
191        ])
192    }
193}
194
195// Both models put their outputs in the same location
196fn model_outputs(mut outputs: HashMap<String, Tensor>) -> Vec<f32> {
197    const NON_MEMORY_OUTPUT: &str = "Identity";
198    outputs
199        .remove(NON_MEMORY_OUTPUT)
200        .expect("The model has this output")
201        .i((0, 0))
202        .and_then(|tensor| tensor.to_vec1())
203        .expect("The tensor has the correct dimensions")
204}