//! Use something like https://netron.app/ to inspect the models and understand
//! the flow.
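//!
//! Processing flow, as implemented below: incoming audio is buffered into
//! overlapping `BLOCK_LEN`-sample frames with a hop of `BLOCK_SHIFT`. Each
//! frame is transformed with a real FFT, its magnitude spectrum is masked by
//! the first (spectral) ONNX model, the masked spectrum is inverse-transformed,
//! the second (signal) model post-processes the resulting time-domain frame,
//! and the results are overlap-added into the output buffer. Both models carry
//! recurrent state between calls via their "Identity_1" outputs.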
use std::collections::HashMap;

use candle_core::{Device, IndexOp, Tensor};
use candle_onnx::onnx::ModelProto;
use candle_onnx::prost::Message;
use realfft::RealFftPlanner;
use rustfft::num_complex::Complex;

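/// Streaming DTLN-style denoiser. Feed it `BLOCK_SHIFT` samples at a time via
/// [`Engine::feed`] and it returns the same number of denoised samples per call.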
pub struct Engine {
    spectral_model: ModelProto,
    signal_model: ModelProto,

    fft_planner: RealFftPlanner<f32>,
    fft_scratch: Vec<Complex<f32>>,
    spectrum: [Complex<f32>; FFT_OUT_SIZE],
    signal: [f32; BLOCK_LEN],

    in_magnitude: [f32; FFT_OUT_SIZE],
    in_phase: [f32; FFT_OUT_SIZE],

    spectral_memory: Tensor,
    signal_memory: Tensor,

    in_buffer: [f32; BLOCK_LEN],
    out_buffer: [f32; BLOCK_LEN],
}

// 32 ms @ 16 kHz per the DTLN docs: https://github.com/breizhn/DTLN
pub const BLOCK_LEN: usize = 512;
// 8 ms @ 16 kHz per the DTLN docs.
pub const BLOCK_SHIFT: usize = 128;
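// A real-to-complex FFT of `BLOCK_LEN` real samples yields `BLOCK_LEN / 2 + 1`
// complex bins (from DC up to and including the Nyquist bin).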
pub const FFT_OUT_SIZE: usize = BLOCK_LEN / 2 + 1;

impl Engine {
    pub fn new() -> Self {
        let mut fft_planner = RealFftPlanner::new();
        let fft_planned = fft_planner.plan_fft_forward(BLOCK_LEN);
        let scratch_len = fft_planned.get_scratch_len();
        Self {
            // Models are 1.5MB and 2.5MB respectively. It's worth the binary
            // size increase not to have to distribute the models separately.
            spectral_model: ModelProto::decode(
                include_bytes!("../models/model_1_converted_simplified.onnx").as_slice(),
            )
            .expect("The model should decode"),
            signal_model: ModelProto::decode(
                include_bytes!("../models/model_2_converted_simplified.onnx").as_slice(),
            )
            .expect("The model should decode"),
            fft_planner,
            fft_scratch: vec![Complex::ZERO; scratch_len],
            spectrum: [Complex::ZERO; FFT_OUT_SIZE],
            signal: [0f32; BLOCK_LEN],

            in_magnitude: [0f32; FFT_OUT_SIZE],
            in_phase: [0f32; FFT_OUT_SIZE],

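            // Recurrent state carried between calls. The (1, 2, BLOCK_SHIFT, 2)
            // shape matches the converted DTLN models' state inputs (presumably
            // two LSTM layers, each with a hidden and a cell state of
            // BLOCK_SHIFT units).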
            spectral_memory: Tensor::from_slice::<_, f32>(
                &[0f32; 2 * BLOCK_SHIFT * 2],
                (1, 2, BLOCK_SHIFT, 2),
                &Device::Cpu,
            )
            .expect("Tensor has the correct dimensions"),
            signal_memory: Tensor::from_slice::<_, f32>(
                &[0f32; 2 * BLOCK_SHIFT * 2],
                (1, 2, BLOCK_SHIFT, 2),
                &Device::Cpu,
            )
            .expect("Tensor has the correct dimensions"),
            out_buffer: [0f32; BLOCK_LEN],
            in_buffer: [0f32; BLOCK_LEN],
        }
    }

    /// Add a chunk of `BLOCK_SHIFT` samples and get back the denoised chunk four feeds later.
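    ///
    /// A minimal usage sketch (the `dtln` crate path and the `noisy` buffer are
    /// placeholders for illustration, not items defined by this crate):
    ///
    /// ```ignore
    /// use dtln::{Engine, BLOCK_SHIFT};
    ///
    /// let mut engine = Engine::new();
    /// let noisy = vec![0f32; BLOCK_SHIFT * 10];
    /// let mut clean = Vec::new();
    /// for chunk in noisy.chunks_exact(BLOCK_SHIFT) {
    ///     // Each call returns BLOCK_SHIFT denoised samples; the output lags
    ///     // the input by a few chunks while the overlap-add buffer fills up.
    ///     clean.extend_from_slice(&engine.feed(chunk));
    /// }
    /// ```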
    pub fn feed(&mut self, samples: &[f32]) -> [f32; BLOCK_SHIFT] {
        /// The name of the memory (recurrent state) output node in both ONNX networks, see
        /// [Dual-Signal Transformation LSTM Network for Real-Time Noise Suppression](https://arxiv.org/abs/2005.07551).
        const MEMORY_OUTPUT: &str = "Identity_1";

        debug_assert_eq!(samples.len(), BLOCK_SHIFT);

        // shift the old samples left and place the new samples at the end of
        // the `in_buffer`
        self.in_buffer.copy_within(BLOCK_SHIFT.., 0);
        self.in_buffer[(BLOCK_LEN - BLOCK_SHIFT)..].copy_from_slice(samples);

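        // The two models mirror the DTLN architecture: the spectral model
        // estimates a per-bin mask for the magnitude spectrum, and the signal
        // model post-processes the masked time-domain frame. Each model also
        // emits its recurrent state ("Identity_1"), which is stored and fed
        // back in on the next call.
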
        // run inference
        let inputs = self.spectral_inputs();
        let mut spectral_outputs = candle_onnx::simple_eval(&self.spectral_model, inputs)
            .expect("The embedded model should evaluate");
        self.spectral_memory = spectral_outputs
            .remove(MEMORY_OUTPUT)
            .expect("The model has an output named Identity_1");
        let inputs = self.signal_inputs(spectral_outputs);
        let mut signal_outputs = candle_onnx::simple_eval(&self.signal_model, inputs)
            .expect("The embedded model should evaluate");
        self.signal_memory = signal_outputs
            .remove(MEMORY_OUTPUT)
            .expect("The model has an output named Identity_1");
        let model_output = model_outputs(signal_outputs);

        // shift the `out_buffer` left and zero the freed-up tail, then
        // overlap-add the new frame onto the whole buffer. The zeros matter:
        // later frames are accumulated on top of this region.
        self.out_buffer.copy_within(BLOCK_SHIFT.., 0);
        self.out_buffer[BLOCK_LEN - BLOCK_SHIFT..].fill(0f32);
        for (a, b) in self.out_buffer.iter_mut().zip(model_output) {
            *a += b;
        }

        // the samples at the front of the `out_buffer` have accumulated
        // contributions from BLOCK_LEN / BLOCK_SHIFT = 4 overlapping frames
        // and are now fully denoised
        self.out_buffer[..BLOCK_SHIFT]
            .try_into()
            .expect("len is correct")
    }

    fn spectral_inputs(&mut self) -> HashMap<String, Tensor> {
        // Plan the forward FFT (the planner reuses plans for the same length)
        let fft = self.fft_planner.plan_fft_forward(BLOCK_LEN);

        // Perform the real-to-complex FFT on a copy of `in_buffer`, since the
        // FFT uses its input buffer as scratch space
        let mut fft_in = self.in_buffer;
        fft.process_with_scratch(&mut fft_in, &mut self.spectrum, &mut self.fft_scratch)
            .expect("The fft should run, there is enough scratch space");

        // Generate magnitude and phase
        for ((magnitude, phase), complex) in self
            .in_magnitude
            .iter_mut()
            .zip(self.in_phase.iter_mut())
            .zip(self.spectrum)
        {
            *magnitude = complex.norm();
            *phase = complex.arg();
        }

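        // Only the magnitude spectrum is fed to the spectral model; the phase
        // is kept in `self.in_phase` and reused in `signal_inputs` to
        // reconstruct the complex spectrum after masking.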
        const SPECTRUM_INPUT: &str = "input_2";
        const MEMORY_INPUT: &str = "input_3";
        let spectrum =
            Tensor::from_slice::<_, f32>(&self.in_magnitude, (1, 1, FFT_OUT_SIZE), &Device::Cpu)
                .expect("the in magnitude has enough elements to fill the Tensor");

        HashMap::from([
            (SPECTRUM_INPUT.to_string(), spectrum),
            (MEMORY_INPUT.to_string(), self.spectral_memory.clone()),
        ])
    }

    fn signal_inputs(&mut self, outputs: HashMap<String, Tensor>) -> HashMap<String, Tensor> {
        let magnitude_weight = model_outputs(outputs);

        // Apply mask and reconstruct complex spectrum
        // every element is overwritten below, so the initial value is irrelevant
        let mut spectrum = [Complex::ZERO; FFT_OUT_SIZE];
        for i in 0..FFT_OUT_SIZE {
            let magnitude = self.in_magnitude[i] * magnitude_weight[i];
            let phase = self.in_phase[i];
            let real = magnitude * phase.cos();
            let imag = magnitude * phase.sin();
            spectrum[i] = Complex::new(real, imag);
        }

        // Force exactly zero imaginary parts on the DC (i = 0) and Nyquist
        // (i = N/2) bins, as required for the spectrum of a real signal
        let magnitude = self.in_magnitude[0] * magnitude_weight[0];
        spectrum[0] = Complex::new(magnitude, 0.0);
        let magnitude = self.in_magnitude[FFT_OUT_SIZE - 1] * magnitude_weight[FFT_OUT_SIZE - 1];
        spectrum[FFT_OUT_SIZE - 1] = Complex::new(magnitude, 0.0);

        // Perform complex-to-real IFFT
        let ifft = self.fft_planner.plan_fft_inverse(BLOCK_LEN);
        ifft.process_with_scratch(&mut spectrum, &mut self.signal, &mut self.fft_scratch)
            .expect("The fft should run, there is enough scratch space");

        // Normalize the IFFT output; the inverse transform is unnormalized, so
        // each sample must be divided by the FFT length
        for real in &mut self.signal {
            *real /= BLOCK_LEN as f32;
        }

        const SIGNAL_INPUT: &str = "input_4";
        const SIGNAL_MEMORY: &str = "input_5";
        let signal_input =
            Tensor::from_slice::<_, f32>(&self.signal, (1, 1, BLOCK_LEN), &Device::Cpu)
                .expect("the signal has enough elements to fill the Tensor");

        HashMap::from([
            (SIGNAL_INPUT.to_string(), signal_input),
            (SIGNAL_MEMORY.to_string(), self.signal_memory.clone()),
        ])
    }
}

// Both models expose their primary (non-memory) output under the same node name
fn model_outputs(mut outputs: HashMap<String, Tensor>) -> Vec<f32> {
    const NON_MEMORY_OUTPUT: &str = "Identity";
    outputs
        .remove(NON_MEMORY_OUTPUT)
        .expect("The model has an output named Identity")
        .i((0, 0))
        .and_then(|tensor| tensor.to_vec1())
        .expect("The tensor has the correct dimensions")
}