From d35f1c0b63d861d4db3030aa92766971a5aa067a Mon Sep 17 00:00:00 2001 From: David Kleingeld Date: Wed, 27 Aug 2025 16:18:11 +0200 Subject: [PATCH] Make every sound go through the webrtc APM for echo cancellation Also adds an inspect_buffer method to rodio sources through an extension trait. We use it to pipe everything through the apm echo canceller. --- Cargo.lock | 6 +- crates/audio/src/audio.rs | 57 ++++--- crates/audio/src/rodio_ext.rs | 152 +++++++++++++++++- crates/livekit_client/src/livekit_client.rs | 2 +- .../src/livekit_client/playback.rs | 95 +++++------ .../src/livekit_client/playback/source.rs | 4 +- 6 files changed, 238 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index baad57d52298c9ae88297369d771eb6fe7f63829..63561664da3f951657484e880416e4ace80ce0b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9400,7 +9400,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -13856,7 +13856,7 @@ dependencies = [ [[package]] name = "rodio" version = "0.21.1" -source = "git+https://github.com/RustAudio/rodio?branch=microphone#cad73716a363a5ba92fcb73ec37a4b98a7d7de5f" +source = "git+https://github.com/RustAudio/rodio?branch=microphone#0e6e6436b3a97f4af72baafe11a02ade2b457b62" dependencies = [ "cpal", "dasp_sample", @@ -18943,7 +18943,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] diff --git a/crates/audio/src/audio.rs b/crates/audio/src/audio.rs index a5a4e721d345f30b74276c1c1d6d6f6358c549d9..ee046e6fde3a791ca7bdfde138e6cacda75d0576 100644 --- a/crates/audio/src/audio.rs +++ b/crates/audio/src/audio.rs @@ -3,14 +3,29 @@ use collections::HashMap; use gpui::{App, 
BorrowAppContext, Global}; use libwebrtc::native::apm; use parking_lot::Mutex; -use rodio::{Decoder, OutputStream, OutputStreamBuilder, Source, mixer::Mixer, source::Buffered}; +use rodio::{ + Decoder, OutputStream, OutputStreamBuilder, Source, cpal::Sample, mixer::Mixer, + source::Buffered, +}; use settings::Settings; -use std::{io::Cursor, sync::Arc}; +use std::{io::Cursor, num::NonZero, sync::Arc}; use util::ResultExt; mod audio_settings; mod rodio_ext; pub use audio_settings::AudioSettings; +pub use rodio_ext::RodioExt; + +// NOTE: We use WebRTC's mixer which only supports +// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up" +// for audio output devices like speakers/bluetooth, we just hard-code +// this; and downsample when we need to. +// +// Since most noise cancelling requires 16kHz we will move to +// that in the future. Same for channel count. That should be input +// channels and fixed to 1. +pub const SAMPLE_RATE: NonZero = NonZero::new(48000).expect("not zero"); +pub const CHANNEL_COUNT: NonZero = NonZero::new(2).expect("not zero"); pub fn init(cx: &mut App) { AudioSettings::register(cx); @@ -44,7 +59,7 @@ impl Sound { pub struct Audio { output_handle: Option, output_mixer: Option, - echo_canceller: Arc>, + pub echo_canceller: Arc>, source_cache: HashMap>>>>, } @@ -64,26 +79,32 @@ impl Default for Audio { impl Global for Audio {} impl Audio { - fn ensure_output_exists(&mut self) -> Option<&OutputStream> { + fn ensure_output_exists(&mut self) -> Option<&Mixer> { if self.output_handle.is_none() { self.output_handle = OutputStreamBuilder::open_default_stream().log_err(); - if let Some(output_handle) = self.output_handle { - let config = output_handle.config(); - let (mixer, source) = - rodio::mixer::mixer(config.channel_count(), config.sample_rate()); + if let Some(output_handle) = &self.output_handle { + let (mixer, source) = rodio::mixer::mixer(CHANNEL_COUNT, SAMPLE_RATE); self.output_mixer = Some(mixer); let echo_canceller = 
Arc::clone(&self.echo_canceller); - let source = source.inspect_buffered( - |buffer| echo_canceller.lock().process_reverse_stream(&mut buf), - config.sample_rate().get() as i32, - config.channel_count().get().into(), - ); + const BUFFER_SIZE: usize = // echo canceller wants 10ms of audio + (SAMPLE_RATE.get() as usize / 100) * CHANNEL_COUNT.get() as usize; + let source = source.inspect_buffer::(move |buffer| { + let mut buf: [i16; _] = buffer.map(|s| s.to_sample()); + echo_canceller + .lock() + .process_reverse_stream( + &mut buf, + SAMPLE_RATE.get() as i32, + CHANNEL_COUNT.get().into(), + ) + .expect("Audio input and output threads should not panic"); + }); output_handle.mixer().add(source); } } - self.output_handle.as_ref() + self.output_mixer.as_ref() } pub fn play_source( @@ -91,10 +112,10 @@ impl Audio { cx: &mut App, ) -> anyhow::Result<()> { cx.update_default_global(|this: &mut Self, _cx| { - let output_handle = this + let output_mixer = this .ensure_output_exists() .ok_or_else(|| anyhow!("Could not open audio output"))?; - output_handle.mixer().add(source); + output_mixer.add(source); Ok(()) }) } @@ -102,9 +123,9 @@ impl Audio { pub fn play_sound(sound: Sound, cx: &mut App) { cx.update_default_global(|this: &mut Self, cx| { let source = this.sound_source(sound, cx).log_err()?; - let output_handle = this.ensure_output_exists()?; + let output_mixer = this.ensure_output_exists()?; - output_handle.mixer().add(source); + output_mixer.add(source); Some(()) }); } diff --git a/crates/audio/src/rodio_ext.rs b/crates/audio/src/rodio_ext.rs index 055ff9003ab483db9f0364bc50cfbe707dd21275..23e85a4163ad64a8f03e7bbea8a5cdd6f1918d8d 100644 --- a/crates/audio/src/rodio_ext.rs +++ b/crates/audio/src/rodio_ext.rs @@ -4,7 +4,7 @@ pub trait RodioExt: Source + Sized { fn process_buffer(self, callback: F) -> ProcessBuffer where F: FnMut(&mut [rodio::Sample; N]); - fn inspect_buffer(self, callback: F) -> ProcessBuffer + fn inspect_buffer(self, callback: F) -> InspectBuffer where F: 
FnMut(&[rodio::Sample; N]); } @@ -21,7 +21,7 @@ impl RodioExt for S { next: N, } } - fn inspect_buffer(self, callback: F) -> ProcessBuffer + fn inspect_buffer(self, callback: F) -> InspectBuffer where F: FnMut(&[rodio::Sample; N]), { @@ -29,7 +29,7 @@ impl RodioExt for S { inner: self, callback, buffer: [0.0; N], - next: N, + free: 0, } } } @@ -41,7 +41,13 @@ where { inner: S, callback: F, + /// Buffer used for both input and output. buffer: [rodio::Sample; N], + /// Next already processed sample is at this index + /// in buffer. + /// + /// If this is equal to the length of the buffer we have no more samples and + /// we must get new ones and process them next: usize, } @@ -91,3 +97,143 @@ where self.inner.total_duration() } } + +pub struct InspectBuffer +where + S: Source + Sized, + F: FnMut(&[rodio::Sample; N]), +{ + inner: S, + callback: F, + /// Stores already emitted samples, once it's full we call the callback. + buffer: [rodio::Sample; N], + /// Next free element in buffer. If this is equal to the buffer length + /// we have no more free elements. 
+ free: usize, +} + +impl Iterator for InspectBuffer +where + S: Source + Sized, + F: FnMut(&[rodio::Sample; N]), +{ + type Item = rodio::Sample; + + fn next(&mut self) -> Option { + let Some(sample) = self.inner.next() else { + return None; + }; + + self.buffer[self.free] = sample; + self.free += 1; + + if self.free == self.buffer.len() { + (self.callback)(&self.buffer); + self.free = 0 + } + + Some(sample) + } +} + +impl Source for InspectBuffer +where + S: Source + Sized, + F: FnMut(&[rodio::Sample; N]), +{ + fn current_span_len(&self) -> Option { + // TODO dvdsk this should be a spanless Source + None + } + + fn channels(&self) -> rodio::ChannelCount { + self.inner.channels() + } + + fn sample_rate(&self) -> rodio::SampleRate { + self.inner.sample_rate() + } + + fn total_duration(&self) -> Option { + self.inner.total_duration() + } +} + +#[cfg(test)] +mod tests { + use rodio::static_buffer::StaticSamplesBuffer; + + use super::*; + + #[cfg(test)] + mod process_buffer { + use super::*; + + #[test] + fn callback_gets_all_samples() { + const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + let input = + StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES); + + let _ = input + .process_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES)) + .count(); + } + #[test] + fn callback_modifies_yielded() { + const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + let input = + StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES); + + let yielded: Vec<_> = input + .process_buffer::<{ SAMPLES.len() }, _>(|buffer| { + for sample in buffer { + *sample += 1.0; + } + }) + .collect(); + assert_eq!( + yielded, + SAMPLES.into_iter().map(|s| s + 1.0).collect::>() + ) + } + #[test] + fn source_truncates_to_whole_buffers() { + const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + let input = + StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES); + + let yielded = input + 
.process_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3])) + .count(); + assert_eq!(yielded, 3) + } + } + + #[cfg(test)] + mod inspect_buffer { + use super::*; + + #[test] + fn callback_gets_all_samples() { + const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + let input = + StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES); + + let _ = input + .inspect_buffer::<{ SAMPLES.len() }, _>(|buffer| assert_eq!(*buffer, SAMPLES)) + .count(); + } + #[test] + fn source_does_not_truncate() { + const SAMPLES: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0]; + let input = + StaticSamplesBuffer::new(1.try_into().unwrap(), 1.try_into().unwrap(), &SAMPLES); + + let yielded = input + .inspect_buffer::<3, _>(|buffer| assert_eq!(buffer, &SAMPLES[..3])) + .count(); + assert_eq!(yielded, SAMPLES.len()) + } + } +} diff --git a/crates/livekit_client/src/livekit_client.rs b/crates/livekit_client/src/livekit_client.rs index c14a29bee31e227edbca5a9384dd01922e2085cb..50035d6eb5a9f11d22d44b223d52c1112f3b28d0 100644 --- a/crates/livekit_client/src/livekit_client.rs +++ b/crates/livekit_client/src/livekit_client.rs @@ -129,7 +129,7 @@ impl Room { cx: &mut App, ) -> Result { if AudioSettings::get_global(cx).rodio_audio { - info!("Using experimental.rodio_audio audio pipeline"); + info!("Using experimental.rodio_audio audio pipeline for output"); playback::play_remote_audio_track(&track.0, cx) } else { Ok(self.playback.play_remote_audio_track(&track.0)) diff --git a/crates/livekit_client/src/livekit_client/playback.rs b/crates/livekit_client/src/livekit_client/playback.rs index 8fa0e7929a15ece0116bc0d8ad5d71218e5d32fc..f28f363d48b1719212c703513fe7c6c332ee9f91 100644 --- a/crates/livekit_client/src/livekit_client/playback.rs +++ b/crates/livekit_client/src/livekit_client/playback.rs @@ -1,6 +1,6 @@ use anyhow::{Context as _, Result}; -use audio::AudioSettings; +use audio::{AudioSettings, CHANNEL_COUNT, SAMPLE_RATE}; use cpal::Sample; use cpal::traits::{DeviceTrait, 
StreamTrait as _}; use dasp_sample::ToSample; @@ -45,13 +45,6 @@ pub(crate) struct AudioStack { next_ssrc: AtomicI32, } -// NOTE: We use WebRTC's mixer which only supports -// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up" -// for audio output devices like speakers/bluetooth, we just hard-code -// this; and downsample when we need to. -const SAMPLE_RATE: NonZero = NonZero::new(48000).expect("not zero"); -const NUM_CHANNELS: NonZero = NonZero::new(2).expect("not zero"); - pub(crate) fn play_remote_audio_track( track: &livekit::track::RemoteAudioTrack, cx: &mut gpui::App, @@ -59,7 +52,6 @@ pub(crate) fn play_remote_audio_track( let stop_handle = Arc::new(AtomicBool::new(false)); let stop_handle_clone = stop_handle.clone(); let stream = source::LiveKitStream::new(cx.background_executor(), track) - .process_buffer(|| apm) .stoppable() .periodic_access(Duration::from_millis(50), move |s| { if stop_handle.load(Ordering::Relaxed) { @@ -101,7 +93,7 @@ impl AudioStack { let source = AudioMixerSource { ssrc: next_ssrc, sample_rate: SAMPLE_RATE.get(), - num_channels: NUM_CHANNELS.get() as u32, + num_channels: CHANNEL_COUNT.get() as u32, buffer: Arc::default(), }; self.mixer.lock().add_source(source.clone()); @@ -141,7 +133,7 @@ impl AudioStack { let apm = self.apm.clone(); let mixer = self.mixer.clone(); async move { - Self::play_output(apm, mixer, SAMPLE_RATE.get(), NUM_CHANNELS.get().into()) + Self::play_output(apm, mixer, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into()) .await .log_err(); } @@ -158,7 +150,7 @@ impl AudioStack { // n.b. this struct's options are always ignored, noise cancellation is provided by apm. 
AudioSourceOptions::default(), SAMPLE_RATE.get(), - NUM_CHANNELS.get().into(), + CHANNEL_COUNT.get().into(), 10, ); @@ -178,16 +170,21 @@ impl AudioStack { } }); let rodio_pipeline = - AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or(false); - let capture_task = self.executor.spawn(async move { - if rodio_pipeline { + AudioSettings::try_read_global(cx, |setting| setting.rodio_audio).unwrap_or_default(); + let capture_task = if rodio_pipeline { + let apm = cx + .try_read_global::(|audio, _| Arc::clone(&audio.echo_canceller)) + .unwrap(); // TODO fixme + self.executor.spawn(async move { info!("Using experimental.rodio_audio audio pipeline"); Self::capture_input_rodio(apm, frame_tx).await - } else { - Self::capture_input(apm, frame_tx, SAMPLE_RATE.get(), NUM_CHANNELS.get().into()) + }) + } else { + self.executor.spawn(async move { + Self::capture_input(apm, frame_tx, SAMPLE_RATE.get(), CHANNEL_COUNT.get().into()) .await - } - }); + }) + }; let on_drop = util::defer(|| { drop(transmit_task); @@ -277,10 +274,9 @@ impl AudioStack { apm: Arc>, frame_tx: UnboundedSender>, ) -> Result<()> { - use crate::livekit_client::playback::source::RodioExt; - const NUM_CHANNELS: usize = 1; + use audio::RodioExt; const LIVEKIT_BUFFER_SIZE: usize = - (SAMPLE_RATE.get() as usize / 100) * NUM_CHANNELS as usize; + (audio::SAMPLE_RATE.get() as usize / 100) * audio::CHANNEL_COUNT.get() as usize; let (stream_error_tx, stream_error_rx) = channel(); @@ -298,31 +294,29 @@ impl AudioStack { ]) .prefer_buffer_sizes(512..) 
.open_stream()?; - let mut stream = UniformSourceIterator::new( - stream, - NonZero::new(1).expect("1 is not zero"), - SAMPLE_RATE, - ) - .limit(LimitSettings::live_performance()) - .process_buffer::(|buffer| { - let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample()); - if let Err(e) = apm - .lock() - .process_stream( - &mut int_buffer, - SAMPLE_RATE.get() as i32, - NUM_CHANNELS as i32, - ) - .context("livekit audio processor error") - { - let _ = stream_error_tx.send(e); - } else { - for (sample, processed) in buffer.iter_mut().zip(&int_buffer) { - *sample = (*processed).to_sample_(); - } - } - }) - .automatic_gain_control(1.0, 4.0, 0.0, 5.0); + info!("Opened microphone: {:?}", stream.config()); + let mut stream = + UniformSourceIterator::new(stream, audio::CHANNEL_COUNT, audio::SAMPLE_RATE) + .limit(LimitSettings::live_performance()) + .process_buffer::(|buffer| { + let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample()); + if let Err(e) = apm + .lock() + .process_stream( + &mut int_buffer, + audio::SAMPLE_RATE.get() as i32, + audio::CHANNEL_COUNT.get() as i32, + ) + .context("livekit audio processor error") + { + let _ = stream_error_tx.send(e); + } else { + for (sample, processed) in buffer.iter_mut().zip(&int_buffer) { + *sample = (*processed).to_sample_(); + } + } + }) + .automatic_gain_control(1.0, 4.0, 0.0, 5.0); loop { let sampled: Vec<_> = stream @@ -342,8 +336,9 @@ impl AudioStack { frame_tx .unbounded_send(AudioFrame { sample_rate: SAMPLE_RATE.get(), - num_channels: NUM_CHANNELS as u32, - samples_per_channel: sampled.len() as u32 / NUM_CHANNELS as u32, + num_channels: audio::CHANNEL_COUNT.get() as u32, + samples_per_channel: sampled.len() as u32 + / audio::CHANNEL_COUNT.get() as u32, data: Cow::Owned(sampled), }) .context("Failed to send audio frame")? 
@@ -445,8 +440,6 @@ impl AudioStack { } } -use crate::livekit_client::playback::source::RodioExt; - use super::LocalVideoTrack; pub enum AudioStream { diff --git a/crates/livekit_client/src/livekit_client/playback/source.rs b/crates/livekit_client/src/livekit_client/playback/source.rs index 62d949f95d215f15926be3c98eab1d4e9dd2d09e..51ceb081d78ecd47f898e025ee8e0cb8235e2f13 100644 --- a/crates/livekit_client/src/livekit_client/playback/source.rs +++ b/crates/livekit_client/src/livekit_client/playback/source.rs @@ -5,7 +5,7 @@ use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame}; use livekit::track::RemoteAudioTrack; use rodio::{Source, buffer::SamplesBuffer, conversions::SampleTypeConverter}; -use crate::livekit_client::playback::{NUM_CHANNELS, SAMPLE_RATE}; +use crate::livekit_client::playback::{CHANNEL_COUNT, SAMPLE_RATE}; fn frame_to_samplesbuffer(frame: AudioFrame) -> SamplesBuffer { let samples = frame.data.iter().copied(); @@ -29,7 +29,7 @@ impl LiveKitStream { let mut stream = NativeAudioStream::new( track.rtc_track(), SAMPLE_RATE.get() as i32, - NUM_CHANNELS.get().into(), + CHANNEL_COUNT.get().into(), ); let (queue_input, queue_output) = rodio::queue::queue(true); // spawn rtc stream