Make every sound go through the webrtc APM for echo cancellation

David Kleingeld created

Also adds a inspect_buffer method to rodio sources trough an
extension trait. We use it to pipe everything trough the apm
echo canceller.

Change summary

Cargo.lock                                                  |   6 
crates/audio/src/audio.rs                                   |  57 +
crates/audio/src/rodio_ext.rs                               | 152 ++++++
crates/livekit_client/src/livekit_client.rs                 |   2 
crates/livekit_client/src/livekit_client/playback.rs        |  95 ++--
crates/livekit_client/src/livekit_client/playback/source.rs |   4 
6 files changed, 238 insertions(+), 78 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9400,7 +9400,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
 dependencies = [
  "cfg-if",
- "windows-targets 0.48.5",
+ "windows-targets 0.52.6",
 ]
 
 [[package]]
@@ -13856,7 +13856,7 @@ dependencies = [
 [[package]]
 name = "rodio"
 version = "0.21.1"
-source = "git+https://github.com/RustAudio/rodio?branch=microphone#cad73716a363a5ba92fcb73ec37a4b98a7d7de5f"
+source = "git+https://github.com/RustAudio/rodio?branch=microphone#0e6e6436b3a97f4af72baafe11a02ade2b457b62"
 dependencies = [
  "cpal",
  "dasp_sample",
@@ -18943,7 +18943,7 @@ version = "0.1.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
 dependencies = [
- "windows-sys 0.48.0",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]

crates/audio/src/audio.rs 🔗

@@ -3,14 +3,29 @@ use collections::HashMap;
 use gpui::{App, BorrowAppContext, Global};
 use libwebrtc::native::apm;
 use parking_lot::Mutex;
-use rodio::{Decoder, OutputStream, OutputStreamBuilder, Source, mixer::Mixer, source::Buffered};
+use rodio::{
+    Decoder, OutputStream, OutputStreamBuilder, Source, cpal::Sample, mixer::Mixer,
+    source::Buffered,
+};
 use settings::Settings;
-use std::{io::Cursor, sync::Arc};
+use std::{io::Cursor, num::NonZero, sync::Arc};
 use util::ResultExt;
 
 mod audio_settings;
 mod rodio_ext;
 pub use audio_settings::AudioSettings;
+pub use rodio_ext::RodioExt;
+
+// NOTE: We use WebRTC's mixer which only supports
+// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
+// for audio output devices like speakers/bluetooth, we just hard-code
+// this; and downsample when we need to.
+//
+// Since most noise cancelling requires 16kHz we will move to
+// that in the future. Same for channel count. That should be input
+// channels and fixed to 1.
+pub const SAMPLE_RATE: NonZero<u32> = NonZero::new(48000).expect("not zero");
+pub const CHANNEL_COUNT: NonZero<u16> = NonZero::new(2).expect("not zero");
 
 pub fn init(cx: &mut App) {
     AudioSettings::register(cx);
@@ -44,7 +59,7 @@ impl Sound {
 pub struct Audio {
     output_handle: Option<OutputStream>,
     output_mixer: Option<Mixer>,
-    echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
+    pub echo_canceller: Arc<Mutex<apm::AudioProcessingModule>>,
     source_cache: HashMap<Sound, Buffered<Decoder<Cursor<Vec<u8>>>>>,
 }
 
@@ -64,26 +79,32 @@ impl Default for Audio {
 impl Global for Audio {}
 
 impl Audio {
-    fn ensure_output_exists(&mut self) -> Option<&OutputStream> {
+    fn ensure_output_exists(&mut self) -> Option<&Mixer> {
         if self.output_handle.is_none() {
             self.output_handle = OutputStreamBuilder::open_default_stream().log_err();
-            if let Some(output_handle) = self.output_handle {
-                let config = output_handle.config();
-                let (mixer, source) =
-                    rodio::mixer::mixer(config.channel_count(), config.sample_rate());
+            if let Some(output_handle) = &self.output_handle {
+                let (mixer, source) = rodio::mixer::mixer(CHANNEL_COUNT, SAMPLE_RATE);
                 self.output_mixer = Some(mixer);
 
                 let echo_canceller = Arc::clone(&self.echo_canceller);
-                let source = source.inspect_buffered(
-                    |buffer| echo_canceller.lock().process_reverse_stream(&mut buf),
-                    config.sample_rate().get() as i32,
-                    config.channel_count().get().into(),
-                );
+                const BUFFER_SIZE: usize = // echo canceller wants 10ms of audio
+                    (SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize;
+                let source = source.inspect_buffer::<BUFFER_SIZE, _>(move |buffer| {
+                    let mut buf: [i16; _] = buffer.map(|s| s.to_sample());
+                    echo_canceller
+                        .lock()
+                        .process_reverse_stream(
+                            &mut buf,
+                            SAMPLE_RATE.get() as i32,
+                            CHANNEL_COUNT.get().into(),
+                        )
+                        .expect("Audio input and output threads should not panic");
+                });
                 output_handle.mixer().add(source);
             }
         }
 
-        self.output_handle.as_ref()
+        self.output_mixer.as_ref()
     }
 
     pub fn play_source(
@@ -91,10 +112,10 @@ impl Audio {
         cx: &mut App,
     ) -> anyhow::Result<()> {
         cx.update_default_global(|this: &mut Self, _cx| {
-            let output_handle = this
+            let output_mixer = this
                 .ensure_output_exists()
                 .ok_or_else(|| anyhow!("Could not open audio output"))?;
-            output_handle.mixer().add(source);
+            output_mixer.add(source);
             Ok(())
         })
     }
@@ -102,9 +123,9 @@ impl Audio {
     pub fn play_sound(sound: Sound, cx: &mut App) {
         cx.update_default_global(|this: &mut Self, cx| {
             let source = this.sound_source(sound, cx).log_err()?;
-            let output_handle = this.ensure_output_exists()?;
+            let output_mixer = this.ensure_output_exists()?;
 
-            output_handle.mixer().add(source);
+            output_mixer.add(source);
             Some(())
         });
     }

crates/audio/src/rodio_ext.rs 🔗

@@ -4,7 +4,7 @@ pub trait RodioExt: Source + Sized {
     fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
     where
         F: FnMut(&mut [rodio::Sample; N]);
-    fn inspect_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
+    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
     where
         F: FnMut(&[rodio::Sample; N]);
 }
@@ -21,7 +21,7 @@ impl<S: Source> RodioExt for S {
             next: N,
         }
     }
-    fn inspect_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
+    fn inspect_buffer<const N: usize, F>(self, callback: F) -> InspectBuffer<N, Self, F>
     where
         F: FnMut(&[rodio::Sample; N]),
     {
@@ -29,7 +29,7 @@ impl<S: Source> RodioExt for S {
             inner: self,
             callback,
             buffer: [0.0; N],
-            next: N,
+            free: 0,
         }
     }
 }
@@ -41,7 +41,13 @@ where
 {
     inner: S,
     callback: F,
+    /// Buffer used for both input and output.
     buffer: [rodio::Sample; N],
+    /// Next already processed sample is at this index
+    /// in buffer.
+    ///
+    /// If this is equal to the length of the buffer we have no more samples and
+    /// we must get new ones and process them
     next: usize,
 }
 
@@ -91,3 +97,143 @@ where
         self.inner.total_duration()
     }
 }
+
+pub struct InspectBuffer<const N: usize, S, F>
+where
+    S: Source + Sized,
+    F: FnMut(&[rodio::Sample; N]),
+{
+    inner: S,
+    callback: F,
+    /// Stores already emitted samples, once its full we call the callback.
+    buffer: [rodio::Sample; N],
+    /// Next free element in buffer. If this is equal to the buffer length
+    /// we have no more free lements.
+    free: usize,
+}
+
+impl<const N: usize, S, F> Iterator for InspectBuffer<N, S, F>
+where
+    S: Source + Sized,
+    F: FnMut(&[rodio::Sample; N]),
+{
+    type Item = rodio::Sample;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let Some(sample) = self.inner.next() else {
+            return None;
+        };
+
+        self.buffer[self.free] = sample;
+        self.free += 1;
+
+        if self.free == self.buffer.len() {
+            (self.callback)(&self.buffer);
+            self.free = 0
+        }
+
+        Some(sample)
+    }
+}
+
+impl<const N: usize, S, F> Source for InspectBuffer<N, S, F>
+where
+    S: Source + Sized,
+    F: FnMut(&[rodio::Sample; N]),
+{
+    fn current_span_len(&self) -> Option<usize> {
+        // TODO dvdsk this should be a spanless Source
+        None
+    }
+
+    fn channels(&self) -> rodio::ChannelCount {
+        self.inner.channels()
+    }
+
+    fn sample_rate(&self) -> rodio::SampleRate {
+        self.inner.sample_rate()
+    }
+
+    fn total_duration(&self) -> Option<std::time::Duration> {
+        self.inner.total_duration()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use rodio::static_buffer::StaticSamplesBuffer;
+
+    use super::*;
+
+    #[cfg(test)]
+    mod process_buffer {
+        use super::*;
+
+        #[test]
+        fn callback_gets_all_samples() {
+            const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
+            let input =
+                StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES);
+
+            let _ = input
+                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
+                .count();
+        }
+        #[test]
+        fn callback_modifies_yielded() {
+            const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
+            let input =
+                StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES);
+
+            let yielded: Vec<_> = input
+                .process_buffer::<{ SAMPLES.len() }, _>(|buffer| {
+                    for sample in buffer {
+                        *sample += 1.0;
+                    }
+                })
+                .collect();
+            assert_eq!(
+                yielded,
+                SAMPLES.into_iter().map(|s| s + 1.0).collect::<Vec<_>>()
+            )
+        }
+        #[test]
+        fn source_truncates_to_whole_buffers() {
+            const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
+            let input =
+                StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES);
+
+            let yielded = input
+                .process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
+                .count();
+            assert_eq!(yielded, 3)
+        }
+    }
+
+    #[cfg(test)]
+    mod inspect_buffer {
+        use super::*;
+
+        #[test]
+        fn callback_gets_all_samples() {
+            const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
+            let input =
+                StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES);
+
+            let _ = input
+                .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES))
+                .count();
+        }
+        #[test]
+        fn source_does_not_truncate() {
+            const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
+            let input =
+                StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES);
+
+            let yielded = input
+                .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3]))
+                .count();
+            assert_eq!(yielded, SAMPLES.len())
+        }
+    }
+}

crates/livekit_client/src/livekit_client.rs 🔗

@@ -129,7 +129,7 @@ impl Room {
         cx: &mut App,
     ) -> Result<playback::AudioStream> {
         if AudioSettings::get_global(cx).rodio_audio {
-            info!("Using experimental.rodio_audio audio pipeline");
+            info!("Using experimental.rodio_audio audio pipeline for output");
             playback::play_remote_audio_track(&track.0, cx)
         } else {
             Ok(self.playback.play_remote_audio_track(&track.0))

crates/livekit_client/src/livekit_client/playback.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{Context as _, Result};
 
-use audio::AudioSettings;
+use audio::{AudioSettings, CHANNEL_COUNT, SAMPLE_RATE};
 use cpal::Sample;
 use cpal::traits::{DeviceTrait, StreamTrait as _};
 use dasp_sample::ToSample;
@@ -45,13 +45,6 @@ pub(crate) struct AudioStack {
     next_ssrc: AtomicI32,
 }
 
-// NOTE: We use WebRTC's mixer which only supports
-// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
-// for audio output devices like speakers/bluetooth, we just hard-code
-// this; and downsample when we need to.
-const SAMPLE_RATE: NonZero<u32> = NonZero::new(48000).expect("not zero");
-const NUM_CHANNELS: NonZero<u16> = NonZero::new(2).expect("not zero");
-
 pub(crate) fn play_remote_audio_track(
     track: &livekit::track::RemoteAudioTrack,
     cx: &mut gpui::App,
@@ -59,7 +52,6 @@ pub(crate) fn play_remote_audio_track(
     let stop_handle = Arc::new(AtomicBool::new(false));
     let stop_handle_clone = stop_handle.clone();
     let stream = source::LiveKitStream::new(cx.background_executor(), track)
-        .process_buffer(|| apm)
         .stoppable()
         .periodic_access(Duration::from_millis(50), move |s| {
             if stop_handle.load(Ordering::Relaxed) {
@@ -101,7 +93,7 @@ impl AudioStack {
         let source = AudioMixerSource {
             ssrc: next_ssrc,
             sample_rate: SAMPLE_RATE.get(),
-            num_channels: NUM_CHANNELS.get() as u32,
+            num_channels: CHANNEL_COUNT.get() as u32,
             buffer: Arc::default(),
         };
         self.mixer.lock().add_source(source.clone());
@@ -141,7 +133,7 @@ impl AudioStack {
             let apm = self.apm.clone();
             let mixer = self.mixer.clone();
             async move {
-                Self::play_output(apm, mixer, SAMPLE_RATE.get(), NUM_CHANNELS.get().into())
+                Self::play_output(apm, mixer, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into())
                     .await
                     .log_err();
             }
@@ -158,7 +150,7 @@ impl AudioStack {
             // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
             AudioSourceOptions::default(),
             SAMPLE_RATE.get(),
-            NUM_CHANNELS.get().into(),
+            CHANNEL_COUNT.get().into(),
             10,
         );
 
@@ -178,16 +170,21 @@ impl AudioStack {
             }
         });
         let rodio_pipeline =
-            AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or(false);
-        let capture_task = self.executor.spawn(async move {
-            if rodio_pipeline {
+            AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or_default();
+        let capture_task = if rodio_pipeline {
+            let apm = cx
+                .try_read_global::<audio::Audio, _>(|audio, _| Arc::clone(&audio.echo_canceller))
+                .unwrap(); // TODO fixme
+            self.executor.spawn(async move {
                 info!("Using experimental.rodio_audio audio pipeline");
                 Self::capture_input_rodio(apm, frame_tx).await
-            } else {
-                Self::capture_input(apm, frame_tx, SAMPLE_RATE.get(), NUM_CHANNELS.get().into())
+            })
+        } else {
+            self.executor.spawn(async move {
+                Self::capture_input(apm, frame_tx, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into())
                     .await
-            }
-        });
+            })
+        };
 
         let on_drop = util::defer(|| {
             drop(transmit_task);
@@ -277,10 +274,9 @@ impl AudioStack {
         apm: Arc<Mutex<apm::AudioProcessingModule>>,
         frame_tx: UnboundedSender<AudioFrame<'static>>,
     ) -> Result<()> {
-        use crate::livekit_client::playback::source::RodioExt;
-        const NUM_CHANNELS: usize = 1;
+        use audio::RodioExt;
         const LIVEKIT_BUFFER_SIZE: usize =
-            (SAMPLE_RATE.get() as usize / 100) * NUM_CHANNELS as usize;
+            (audio::SAMPLE_RATE.get() as usize / 100) * audio::CHANNEL_COUNT.get() as usize;
 
         let (stream_error_tx, stream_error_rx) = channel();
 
@@ -298,31 +294,29 @@ impl AudioStack {
                 ])
                 .prefer_buffer_sizes(512..)
                 .open_stream()?;
-            let mut stream = UniformSourceIterator::new(
-                stream,
-                NonZero::new(1).expect("1 is not zero"),
-                SAMPLE_RATE,
-            )
-            .limit(LimitSettings::live_performance())
-            .process_buffer::<LIVEKIT_BUFFER_SIZE, _>(|buffer| {
-                let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample());
-                if let Err(e) = apm
-                    .lock()
-                    .process_stream(
-                        &mut int_buffer,
-                        SAMPLE_RATE.get() as i32,
-                        NUM_CHANNELS as i32,
-                    )
-                    .context("livekit audio processor error")
-                {
-                    let _ = stream_error_tx.send(e);
-                } else {
-                    for (sample, processed) in buffer.iter_mut().zip(&int_buffer) {
-                        *sample = (*processed).to_sample_();
-                    }
-                }
-            })
-            .automatic_gain_control(1.0, 4.0, 0.0, 5.0);
+            info!("Opened microphone: {:?}", stream.config());
+            let mut stream =
+                UniformSourceIterator::new(stream, audio::CHANNEL_COUNT, audio::SAMPLE_RATE)
+                    .limit(LimitSettings::live_performance())
+                    .process_buffer::<LIVEKIT_BUFFER_SIZE, _>(|buffer| {
+                        let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample());
+                        if let Err(e) = apm
+                            .lock()
+                            .process_stream(
+                                &mut int_buffer,
+                                audio::SAMPLE_RATE.get() as i32,
+                                audio::CHANNEL_COUNT.get() as i32,
+                            )
+                            .context("livekit audio processor error")
+                        {
+                            let _ = stream_error_tx.send(e);
+                        } else {
+                            for (sample, processed) in buffer.iter_mut().zip(&int_buffer) {
+                                *sample = (*processed).to_sample_();
+                            }
+                        }
+                    })
+                    .automatic_gain_control(1.0, 4.0, 0.0, 5.0);
 
             loop {
                 let sampled: Vec<_> = stream
@@ -342,8 +336,9 @@ impl AudioStack {
                 frame_tx
                     .unbounded_send(AudioFrame {
                         sample_rate: SAMPLE_RATE.get(),
-                        num_channels: NUM_CHANNELS as u32,
-                        samples_per_channel: sampled.len() as u32 / NUM_CHANNELS as u32,
+                        num_channels: audio::CHANNEL_COUNT.get() as u32,
+                        samples_per_channel: sampled.len() as u32
+                            / audio::CHANNEL_COUNT.get() as u32,
                         data: Cow::Owned(sampled),
                     })
                     .context("Failed to send audio frame")?
@@ -445,8 +440,6 @@ impl AudioStack {
     }
 }
 
-use crate::livekit_client::playback::source::RodioExt;
-
 use super::LocalVideoTrack;
 
 pub enum AudioStream {

crates/livekit_client/src/livekit_client/playback/source.rs 🔗

@@ -5,7 +5,7 @@ use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
 use livekit::track::RemoteAudioTrack;
 use rodio::{Source, buffer::SamplesBuffer, conversions::SampleTypeConverter};
 
-use crate::livekit_client::playback::{NUM_CHANNELS, SAMPLE_RATE};
+use crate::livekit_client::playback::{CHANNEL_COUNT, SAMPLE_RATE};
 
 fn frame_to_samplesbuffer(frame: AudioFrame) -> SamplesBuffer {
     let samples = frame.data.iter().copied();
@@ -29,7 +29,7 @@ impl LiveKitStream {
         let mut stream = NativeAudioStream::new(
             track.rtc_track(),
             SAMPLE_RATE.get() as i32,
-            NUM_CHANNELS.get().into(),
+            CHANNEL_COUNT.get().into(),
         );
         let (queue_input, queue_output) = rodio::queue::queue(true);
         // spawn rtc stream