Adds microphone input through Rodio

David Kleingeld created

This adds an extension trait that enables us to apply the APM echo canceller to the microphone input.

Change summary

Cargo.lock                                                  | 99 ++++++
Cargo.toml                                                  |  2 
crates/livekit_client/Cargo.toml                            |  3 
crates/livekit_client/src/lib.rs                            | 66 ++--
crates/livekit_client/src/livekit_client/playback.rs        | 61 ++++
crates/livekit_client/src/livekit_client/playback/source.rs | 86 ++++++
crates/livekit_client/src/record.rs                         |  7 
7 files changed, 285 insertions(+), 39 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9635,6 +9635,7 @@ dependencies = [
  "core-video",
  "coreaudio-rs 0.12.1",
  "cpal",
+ "dasp_sample",
  "futures 0.3.31",
  "gpui",
  "gpui_tokio",
@@ -13853,14 +13854,15 @@ dependencies = [
 [[package]]
 name = "rodio"
 version = "0.21.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e40ecf59e742e03336be6a3d53755e789fd05a059fa22dfa0ed624722319e183"
+source = "git+https://github.com/RustAudio/rodio?branch=microphone#bb560f30b17d330459b81afc918f2a4a123c41aa"
 dependencies = [
  "cpal",
  "dasp_sample",
  "hound",
  "num-rational",
+ "rtrb",
  "symphonia",
+ "thiserror 2.0.12",
  "tracing",
 ]
 
@@ -13935,6 +13937,12 @@ dependencies = [
  "zeroize",
 ]
 
+[[package]]
+name = "rtrb"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad8388ea1a9e0ea807e442e8263a699e7edcb320ecbcd21b4fa8ff859acce3ba"
+
 [[package]]
 name = "rules_library"
 version = "0.1.0"
@@ -15896,12 +15904,53 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9"
 dependencies = [
  "lazy_static",
+ "symphonia-bundle-flac",
+ "symphonia-bundle-mp3",
+ "symphonia-codec-aac",
  "symphonia-codec-pcm",
+ "symphonia-codec-vorbis",
  "symphonia-core",
+ "symphonia-format-isomp4",
+ "symphonia-format-ogg",
  "symphonia-format-riff",
  "symphonia-metadata",
 ]
 
+[[package]]
+name = "symphonia-bundle-flac"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97"
+dependencies = [
+ "log",
+ "symphonia-core",
+ "symphonia-metadata",
+ "symphonia-utils-xiph",
+]
+
+[[package]]
+name = "symphonia-bundle-mp3"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c01c2aae70f0f1fb096b6f0ff112a930b1fb3626178fba3ae68b09dce71706d4"
+dependencies = [
+ "lazy_static",
+ "log",
+ "symphonia-core",
+ "symphonia-metadata",
+]
+
+[[package]]
+name = "symphonia-codec-aac"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdbf25b545ad0d3ee3e891ea643ad115aff4ca92f6aec472086b957a58522f70"
+dependencies = [
+ "lazy_static",
+ "log",
+ "symphonia-core",
+]
+
 [[package]]
 name = "symphonia-codec-pcm"
 version = "0.5.4"
@@ -15912,6 +15961,17 @@ dependencies = [
  "symphonia-core",
 ]
 
+[[package]]
+name = "symphonia-codec-vorbis"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30"
+dependencies = [
+ "log",
+ "symphonia-core",
+ "symphonia-utils-xiph",
+]
+
 [[package]]
 name = "symphonia-core"
 version = "0.5.4"
@@ -15925,6 +15985,31 @@ dependencies = [
  "log",
 ]
 
+[[package]]
+name = "symphonia-format-isomp4"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844"
+dependencies = [
+ "encoding_rs",
+ "log",
+ "symphonia-core",
+ "symphonia-metadata",
+ "symphonia-utils-xiph",
+]
+
+[[package]]
+name = "symphonia-format-ogg"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ada3505789516bcf00fc1157c67729eded428b455c27ca370e41f4d785bfa931"
+dependencies = [
+ "log",
+ "symphonia-core",
+ "symphonia-metadata",
+ "symphonia-utils-xiph",
+]
+
 [[package]]
 name = "symphonia-format-riff"
 version = "0.5.4"
@@ -15949,6 +16034,16 @@ dependencies = [
  "symphonia-core",
 ]
 
+[[package]]
+name = "symphonia-utils-xiph"
+version = "0.5.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe"
+dependencies = [
+ "symphonia-core",
+ "symphonia-metadata",
+]
+
 [[package]]
 name = "syn"
 version = "1.0.109"

Cargo.toml 🔗

@@ -367,7 +367,7 @@ remote_server = { path = "crates/remote_server" }
 repl = { path = "crates/repl" }
 reqwest_client = { path = "crates/reqwest_client" }
 rich_text = { path = "crates/rich_text" }
-rodio = { version = "0.21.1", default-features = false }
+rodio = { git = "https://github.com/RustAudio/rodio", branch = "microphone"}
 rope = { path = "crates/rope" }
 rpc = { path = "crates/rpc" }
 rules_library = { path = "crates/rules_library" }

crates/livekit_client/Cargo.toml 🔗

@@ -41,7 +41,8 @@ tokio-tungstenite.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
 
-rodio = { workspace = true, features = ["wav_output"] }
+rodio = { workspace = true, features = ["wav_output", "recording"] }
+dasp_sample = "0.11"
 
 [target.'cfg(not(any(all(target_os = "windows", target_env = "gnu"), target_os = "freebsd")))'.dependencies]
 libwebrtc = { rev = "5f04705ac3f356350ae31534ffbc476abc9ea83d", git = "https://github.com/zed-industries/livekit-rust-sdks" }

crates/livekit_client/src/lib.rs 🔗

@@ -9,19 +9,19 @@ use rodio::DeviceTrait as _;
 mod record;
 pub use record::CaptureInput;
 
-#[cfg(not(any(
-    test,
-    feature = "test-support",
-    all(target_os = "windows", target_env = "gnu"),
-    target_os = "freebsd"
-)))]
+// #[cfg(not(any(
+//     test,
+//     feature = "test-support",
+//     all(target_os = "windows", target_env = "gnu"),
+//     target_os = "freebsd"
+// )))]
 mod livekit_client;
-#[cfg(not(any(
-    test,
-    feature = "test-support",
-    all(target_os = "windows", target_env = "gnu"),
-    target_os = "freebsd"
-)))]
+// #[cfg(not(any(
+//     test,
+//     feature = "test-support",
+//     all(target_os = "windows", target_env = "gnu"),
+//     target_os = "freebsd"
+// )))]
 pub use livekit_client::*;
 
 // If you need proper LSP in livekit_client you've got to comment
@@ -29,27 +29,27 @@ pub use livekit_client::*;
 // - the mods: mock_client & test and their conditional blocks
 // - the pub use mock_client::* and their conditional blocks
 
-#[cfg(any(
-    test,
-    feature = "test-support",
-    all(target_os = "windows", target_env = "gnu"),
-    target_os = "freebsd"
-))]
-mod mock_client;
-#[cfg(any(
-    test,
-    feature = "test-support",
-    all(target_os = "windows", target_env = "gnu"),
-    target_os = "freebsd"
-))]
-pub mod test;
-#[cfg(any(
-    test,
-    feature = "test-support",
-    all(target_os = "windows", target_env = "gnu"),
-    target_os = "freebsd"
-))]
-pub use mock_client::*;
+// #[cfg(any(
+//     test,
+//     feature = "test-support",
+//     all(target_os = "windows", target_env = "gnu"),
+//     target_os = "freebsd"
+// ))]
+// mod mock_client;
+// #[cfg(any(
+//     test,
+//     feature = "test-support",
+//     all(target_os = "windows", target_env = "gnu"),
+//     target_os = "freebsd"
+// ))]
+// pub mod test;
+// #[cfg(any(
+//     test,
+//     feature = "test-support",
+//     all(target_os = "windows", target_env = "gnu"),
+//     target_os = "freebsd"
+// ))]
+// pub use mock_client::*;
 
 #[derive(Debug, Clone)]
 pub enum Participant {

crates/livekit_client/src/livekit_client/playback.rs 🔗

@@ -1,6 +1,8 @@
 use anyhow::{Context as _, Result};
 
+use cpal::Sample;
 use cpal::traits::{DeviceTrait, StreamTrait as _};
+use dasp_sample::ToSample;
 use futures::channel::mpsc::UnboundedSender;
 use futures::{Stream, StreamExt as _};
 use gpui::{
@@ -19,7 +21,9 @@ use livekit::webrtc::{
 };
 use parking_lot::Mutex;
 use rodio::Source;
+use rodio::source::{LimitSettings, UniformSourceIterator};
 use std::cell::RefCell;
+use std::num::NonZero;
 use std::sync::Weak;
 use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};
 use std::time::Duration;
@@ -254,6 +258,63 @@ impl AudioStack {
         }
     }
 
+    async fn capture_input_rodio(
+        apm: Arc<Mutex<apm::AudioProcessingModule>>,
+        frame_tx: UnboundedSender<AudioFrame<'static>>,
+        sample_rate: u32,
+        num_channels: u32,
+    ) -> Result<()> {
+        use crate::livekit_client::playback::source::RodioExt;
+        const NUM_CHANNELS: usize = 1;
+        const LIVEKIT_BUFFER_SIZE: usize = (SAMPLE_RATE as usize / 100) * NUM_CHANNELS as usize;
+
+        thread::spawn(move || {
+            let stream = rodio::microphone::MicrophoneBuilder::new()
+                .default_device()?
+                .default_config()?
+                .open_stream()?;
+            let mut stream = UniformSourceIterator::new(
+                stream,
+                NonZero::new(1).expect("1 is not zero"),
+                NonZero::new(SAMPLE_RATE).expect("constant is not zero"),
+            )
+            .limit(LimitSettings::live_performance())
+            .process_buffer::<LIVEKIT_BUFFER_SIZE, _>(|buffer| {
+                let mut int_buffer: [i16; _] = buffer.map(|s| s.to_sample());
+                apm.lock()
+                    .process_stream(&mut int_buffer, sample_rate as i32, num_channels as i32)
+                    .unwrap(); // TODO dvdsk fix this
+                for (sample, processed) in buffer.iter_mut().zip(&int_buffer) {
+                    *sample = (*processed).to_sample_();
+                }
+            })
+            .automatic_gain_control(1.0, 4.0, 0.0, 5.0);
+
+            loop {
+                let sampled = stream
+                    .by_ref()
+                    .take(LIVEKIT_BUFFER_SIZE)
+                    .map(|s| s.to_sample())
+                    .collect();
+
+                if frame_tx
+                    .unbounded_send(AudioFrame {
+                        data: Cow::Owned(sampled),
+                        sample_rate,
+                        num_channels,
+                        samples_per_channel: sample_rate / 100,
+                    })
+                    .is_err()
+                {
+                    break;
+                }
+            }
+            Ok::<(), anyhow::Error>(())
+        });
+
+        Ok(())
+    }
+
     async fn capture_input(
         apm: Arc<Mutex<apm::AudioProcessingModule>>,
         frame_tx: UnboundedSender<AudioFrame<'static>>,

crates/livekit_client/src/livekit_client/playback/source.rs 🔗

@@ -1,3 +1,5 @@
+use std::num::NonZero;
+
 use futures::StreamExt;
 use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
 use livekit::track::RemoteAudioTrack;
@@ -9,7 +11,11 @@ fn frame_to_samplesbuffer(frame: AudioFrame) -> SamplesBuffer {
     let samples = frame.data.iter().copied();
     let samples = SampleTypeConverter::<_, _>::new(samples);
     let samples: Vec<f32> = samples.collect();
-    SamplesBuffer::new(frame.num_channels as u16, frame.sample_rate, samples)
+    SamplesBuffer::new(
+        NonZero::new(frame.num_channels as u16).expect("audio frame channels is nonzero"),
+        NonZero::new(frame.sample_rate).expect("audio frame sample rate is nonzero"),
+        samples,
+    )
 }
 
 pub struct LiveKitStream {
@@ -65,3 +71,81 @@ impl Source for LiveKitStream {
         self.inner.total_duration()
     }
 }
+
+pub trait RodioExt: Source + Sized {
+    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
+    where
+        F: FnMut(&mut [rodio::Sample; N]);
+}
+
+impl<S: Source> RodioExt for S {
+    fn process_buffer<const N: usize, F>(self, callback: F) -> ProcessBuffer<N, Self, F>
+    where
+        F: FnMut(&mut [rodio::Sample; N]),
+    {
+        ProcessBuffer {
+            inner: self,
+            callback,
+            buffer: [0.0; N],
+            next: N,
+        }
+    }
+}
+
+pub struct ProcessBuffer<const N: usize, S, F>
+where
+    S: Source + Sized,
+    F: FnMut(&mut [rodio::Sample; N]),
+{
+    inner: S,
+    callback: F,
+    buffer: [rodio::Sample; N],
+    next: usize,
+}
+
+impl<const N: usize, S, F> Iterator for ProcessBuffer<S, F, N>
+where
+    S: Source + Sized,
+    F: FnMut(&mut [rodio::Sample; N]),
+{
+    type Item = rodio::Sample;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.next += 1;
+        if self.next < self.buffer.len() {
+            let sample = self.buffer[self.next];
+            return Some(sample);
+        }
+
+        for sample in &mut self.buffer {
+            *sample = self.inner.next()?
+        }
+        (self.callback)(&mut self.buffer);
+
+        self.next = 0;
+        Some(self.buffer[0])
+    }
+}
+
+// TODO dvdsk this should be a spanless Source
+impl<const N: usize, S, F> Source for ProcessBuffer<N, S, F>
+where
+    S: Source + Sized,
+    F: FnMut(&mut [rodio::Sample; N]),
+{
+    fn current_span_len(&self) -> Option<usize> {
+        None
+    }
+
+    fn channels(&self) -> rodio::ChannelCount {
+        self.inner.channels()
+    }
+
+    fn sample_rate(&self) -> rodio::SampleRate {
+        self.inner.sample_rate()
+    }
+
+    fn total_duration(&self) -> Option<std::time::Duration> {
+        self.inner.total_duration()
+    }
+}

crates/livekit_client/src/record.rs 🔗

@@ -1,5 +1,6 @@
 use std::{
     env,
+    num::NonZero,
     path::{Path, PathBuf},
     sync::{Arc, Mutex},
     time::Duration,
@@ -83,7 +84,11 @@ fn write_out(
             .expect("Stream has ended, callback cant hold the lock"),
     );
     let samples: Vec<f32> = SampleTypeConverter::<_, f32>::new(samples.into_iter()).collect();
-    let mut samples = SamplesBuffer::new(config.channels(), config.sample_rate().0, samples);
+    let mut samples = SamplesBuffer::new(
+        NonZero::new(config.channels()).expect("config channel is never zero"),
+        NonZero::new(config.sample_rate().0).expect("config sample_rate is never zero"),
+        samples,
+    );
     match rodio::output_to_wav(&mut samples, path) {
         Ok(_) => Ok(()),
         Err(e) => Err(anyhow::anyhow!("Failed to write wav file: {}", e)),