playback.rs

  1use anyhow::{Context as _, Result};
  2
  3use cpal::traits::{DeviceTrait, StreamTrait as _};
  4use futures::channel::mpsc::UnboundedSender;
  5use futures::{Stream, StreamExt as _};
  6use gpui::{
  7    BackgroundExecutor, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, Task,
  8};
  9use libwebrtc::native::{apm, audio_mixer, audio_resampler};
 10use livekit::track;
 11
 12use livekit::webrtc::{
 13    audio_frame::AudioFrame,
 14    audio_source::{AudioSourceOptions, RtcAudioSource, native::NativeAudioSource},
 15    audio_stream::native::NativeAudioStream,
 16    video_frame::{VideoBuffer, VideoFrame, VideoRotation},
 17    video_source::{RtcVideoSource, VideoResolution, native::NativeVideoSource},
 18    video_stream::native::NativeVideoStream,
 19};
 20use parking_lot::Mutex;
 21use std::cell::RefCell;
 22use std::sync::Weak;
 23use std::sync::atomic::{self, AtomicI32};
 24use std::time::Duration;
 25use std::{borrow::Cow, collections::VecDeque, sync::Arc, thread};
 26use util::{ResultExt as _, maybe};
 27
 28pub(crate) struct AudioStack {
 29    executor: BackgroundExecutor,
 30    apm: Arc<Mutex<apm::AudioProcessingModule>>,
 31    mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
 32    _output_task: RefCell<Weak<Task<()>>>,
 33    next_ssrc: AtomicI32,
 34}
 35
 36// NOTE: We use WebRTC's mixer which only supports
 37// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
 38// for audio output devices like speakers/bluetooth, we just hard-code
 39// this; and downsample when we need to.
 40const SAMPLE_RATE: u32 = 48000;
 41const NUM_CHANNELS: u32 = 2;
 42
 43impl AudioStack {
 44    pub(crate) fn new(executor: BackgroundExecutor) -> Self {
 45        let apm = Arc::new(Mutex::new(apm::AudioProcessingModule::new(
 46            true, true, true, true,
 47        )));
 48        let mixer = Arc::new(Mutex::new(audio_mixer::AudioMixer::new()));
 49        Self {
 50            executor,
 51            apm,
 52            mixer,
 53            _output_task: RefCell::new(Weak::new()),
 54            next_ssrc: AtomicI32::new(1),
 55        }
 56    }
 57
 58    pub(crate) fn play_remote_audio_track(
 59        &self,
 60        track: &livekit::track::RemoteAudioTrack,
 61    ) -> AudioStream {
 62        let output_task = self.start_output();
 63
 64        let next_ssrc = self.next_ssrc.fetch_add(1, atomic::Ordering::Relaxed);
 65        let source = AudioMixerSource {
 66            ssrc: next_ssrc,
 67            sample_rate: SAMPLE_RATE,
 68            num_channels: NUM_CHANNELS,
 69            buffer: Arc::default(),
 70        };
 71        self.mixer.lock().add_source(source.clone());
 72
 73        let mut stream = NativeAudioStream::new(
 74            track.rtc_track(),
 75            source.sample_rate as i32,
 76            source.num_channels as i32,
 77        );
 78
 79        let receive_task = self.executor.spawn({
 80            let source = source.clone();
 81            async move {
 82                while let Some(frame) = stream.next().await {
 83                    source.receive(frame);
 84                }
 85            }
 86        });
 87
 88        let mixer = self.mixer.clone();
 89        let on_drop = util::defer(move || {
 90            mixer.lock().remove_source(source.ssrc);
 91            drop(receive_task);
 92            drop(output_task);
 93        });
 94
 95        AudioStream::Output {
 96            _drop: Box::new(on_drop),
 97        }
 98    }
 99
100    pub(crate) fn capture_local_microphone_track(
101        &self,
102    ) -> Result<(crate::LocalAudioTrack, AudioStream)> {
103        let source = NativeAudioSource::new(
104            // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
105            AudioSourceOptions::default(),
106            SAMPLE_RATE,
107            NUM_CHANNELS,
108            10,
109        );
110
111        let track = track::LocalAudioTrack::create_audio_track(
112            "microphone",
113            RtcAudioSource::Native(source.clone()),
114        );
115
116        let apm = self.apm.clone();
117
118        let (frame_tx, mut frame_rx) = futures::channel::mpsc::unbounded();
119        let transmit_task = self.executor.spawn({
120            let source = source.clone();
121            async move {
122                while let Some(frame) = frame_rx.next().await {
123                    source.capture_frame(&frame).await.log_err();
124                }
125            }
126        });
127        let capture_task = self.executor.spawn(async move {
128            Self::capture_input(apm, frame_tx, SAMPLE_RATE, NUM_CHANNELS).await
129        });
130
131        let on_drop = util::defer(|| {
132            drop(transmit_task);
133            drop(capture_task);
134        });
135        return Ok((
136            super::LocalAudioTrack(track),
137            AudioStream::Output {
138                _drop: Box::new(on_drop),
139            },
140        ));
141    }
142
143    fn start_output(&self) -> Arc<Task<()>> {
144        if let Some(task) = self._output_task.borrow().upgrade() {
145            return task;
146        }
147        let task = Arc::new(self.executor.spawn({
148            let apm = self.apm.clone();
149            let mixer = self.mixer.clone();
150            async move {
151                Self::play_output(apm, mixer, SAMPLE_RATE, NUM_CHANNELS)
152                    .await
153                    .log_err();
154            }
155        }));
156        *self._output_task.borrow_mut() = Arc::downgrade(&task);
157        task
158    }
159
160    async fn play_output(
161        apm: Arc<Mutex<apm::AudioProcessingModule>>,
162        mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
163        sample_rate: u32,
164        num_channels: u32,
165    ) -> Result<()> {
166        loop {
167            let mut device_change_listener = DeviceChangeListener::new(false)?;
168            let (output_device, output_config) = crate::default_device(false)?;
169            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
170            let mixer = mixer.clone();
171            let apm = apm.clone();
172            let mut resampler = audio_resampler::AudioResampler::default();
173            let mut buf = Vec::new();
174
175            thread::spawn(move || {
176                let output_stream = output_device.build_output_stream(
177                    &output_config.config(),
178                    {
179                        move |mut data, _info| {
180                            while data.len() > 0 {
181                                if data.len() <= buf.len() {
182                                    let rest = buf.split_off(data.len());
183                                    data.copy_from_slice(&buf);
184                                    buf = rest;
185                                    return;
186                                }
187                                if buf.len() > 0 {
188                                    let (prefix, suffix) = data.split_at_mut(buf.len());
189                                    prefix.copy_from_slice(&buf);
190                                    data = suffix;
191                                }
192
193                                let mut mixer = mixer.lock();
194                                let mixed = mixer.mix(output_config.channels() as usize);
195                                let sampled = resampler.remix_and_resample(
196                                    mixed,
197                                    sample_rate / 100,
198                                    num_channels,
199                                    sample_rate,
200                                    output_config.channels() as u32,
201                                    output_config.sample_rate().0,
202                                );
203                                buf = sampled.to_vec();
204                                apm.lock()
205                                    .process_reverse_stream(
206                                        &mut buf,
207                                        output_config.sample_rate().0 as i32,
208                                        output_config.channels() as i32,
209                                    )
210                                    .ok();
211                            }
212                        }
213                    },
214                    |error| log::error!("error playing audio track: {:?}", error),
215                    Some(Duration::from_millis(100)),
216                );
217
218                let Some(output_stream) = output_stream.log_err() else {
219                    return;
220                };
221
222                output_stream.play().log_err();
223                // Block forever to keep the output stream alive
224                end_on_drop_rx.recv().ok();
225            });
226
227            device_change_listener.next().await;
228            drop(end_on_drop_tx)
229        }
230    }
231
232    async fn capture_input(
233        apm: Arc<Mutex<apm::AudioProcessingModule>>,
234        frame_tx: UnboundedSender<AudioFrame<'static>>,
235        sample_rate: u32,
236        num_channels: u32,
237    ) -> Result<()> {
238        loop {
239            let mut device_change_listener = DeviceChangeListener::new(true)?;
240            let (device, config) = crate::default_device(true)?;
241            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
242            let apm = apm.clone();
243            let frame_tx = frame_tx.clone();
244            let mut resampler = audio_resampler::AudioResampler::default();
245
246            thread::spawn(move || {
247                maybe!({
248                    if let Some(name) = device.name().ok() {
249                        log::info!("Using microphone: {}", name)
250                    } else {
251                        log::info!("Using microphone: <unknown>");
252                    }
253
254                    let ten_ms_buffer_size =
255                        (config.channels() as u32 * config.sample_rate().0 / 100) as usize;
256                    let mut buf: Vec<i16> = Vec::with_capacity(ten_ms_buffer_size);
257
258                    let stream = device
259                        .build_input_stream_raw(
260                            &config.config(),
261                            config.sample_format(),
262                            move |data, _: &_| {
263                                let data =
264                                    crate::get_sample_data(config.sample_format(), data).log_err();
265                                let Some(data) = data else {
266                                    return;
267                                };
268                                let mut data = data.as_slice();
269
270                                while data.len() > 0 {
271                                    let remainder = (buf.capacity() - buf.len()).min(data.len());
272                                    buf.extend_from_slice(&data[..remainder]);
273                                    data = &data[remainder..];
274
275                                    if buf.capacity() == buf.len() {
276                                        let mut sampled = resampler
277                                            .remix_and_resample(
278                                                buf.as_slice(),
279                                                config.sample_rate().0 / 100,
280                                                config.channels() as u32,
281                                                config.sample_rate().0,
282                                                num_channels,
283                                                sample_rate,
284                                            )
285                                            .to_owned();
286                                        apm.lock()
287                                            .process_stream(
288                                                &mut sampled,
289                                                sample_rate as i32,
290                                                num_channels as i32,
291                                            )
292                                            .log_err();
293                                        buf.clear();
294                                        frame_tx
295                                            .unbounded_send(AudioFrame {
296                                                data: Cow::Owned(sampled),
297                                                sample_rate,
298                                                num_channels,
299                                                samples_per_channel: sample_rate / 100,
300                                            })
301                                            .ok();
302                                    }
303                                }
304                            },
305                            |err| log::error!("error capturing audio track: {:?}", err),
306                            Some(Duration::from_millis(100)),
307                        )
308                        .context("failed to build input stream")?;
309
310                    stream.play()?;
311                    // Keep the thread alive and holding onto the `stream`
312                    end_on_drop_rx.recv().ok();
313                    anyhow::Ok(Some(()))
314                })
315                .log_err();
316            });
317
318            device_change_listener.next().await;
319            drop(end_on_drop_tx)
320        }
321    }
322}
323
324use super::LocalVideoTrack;
325
326pub enum AudioStream {
327    Input { _task: Task<()> },
328    Output { _drop: Box<dyn std::any::Any> },
329}
330
331pub(crate) async fn capture_local_video_track(
332    capture_source: &dyn ScreenCaptureSource,
333    cx: &mut gpui::AsyncApp,
334) -> Result<(crate::LocalVideoTrack, Box<dyn ScreenCaptureStream>)> {
335    let metadata = capture_source.metadata()?;
336    let track_source = gpui_tokio::Tokio::spawn(cx, async move {
337        NativeVideoSource::new(VideoResolution {
338            width: metadata.resolution.width.0 as u32,
339            height: metadata.resolution.height.0 as u32,
340        })
341    })?
342    .await?;
343
344    let capture_stream = capture_source
345        .stream(cx.foreground_executor(), {
346            let track_source = track_source.clone();
347            Box::new(move |frame| {
348                if let Some(buffer) = video_frame_buffer_to_webrtc(frame) {
349                    track_source.capture_frame(&VideoFrame {
350                        rotation: VideoRotation::VideoRotation0,
351                        timestamp_us: 0,
352                        buffer,
353                    });
354                }
355            })
356        })
357        .await??;
358
359    Ok((
360        LocalVideoTrack(track::LocalVideoTrack::create_video_track(
361            "screen share",
362            RtcVideoSource::Native(track_source),
363        )),
364        capture_stream,
365    ))
366}
367
368#[derive(Clone)]
369struct AudioMixerSource {
370    ssrc: i32,
371    sample_rate: u32,
372    num_channels: u32,
373    buffer: Arc<Mutex<VecDeque<Vec<i16>>>>,
374}
375
376impl AudioMixerSource {
377    fn receive(&self, frame: AudioFrame) {
378        assert_eq!(
379            frame.data.len() as u32,
380            self.sample_rate * self.num_channels / 100
381        );
382
383        let mut buffer = self.buffer.lock();
384        buffer.push_back(frame.data.to_vec());
385        while buffer.len() > 10 {
386            buffer.pop_front();
387        }
388    }
389}
390
391impl libwebrtc::native::audio_mixer::AudioMixerSource for AudioMixerSource {
392    fn ssrc(&self) -> i32 {
393        self.ssrc
394    }
395
396    fn preferred_sample_rate(&self) -> u32 {
397        self.sample_rate
398    }
399
400    fn get_audio_frame_with_info<'a>(&self, target_sample_rate: u32) -> Option<AudioFrame<'_>> {
401        assert_eq!(self.sample_rate, target_sample_rate);
402        let buf = self.buffer.lock().pop_front()?;
403        Some(AudioFrame {
404            data: Cow::Owned(buf),
405            sample_rate: self.sample_rate,
406            num_channels: self.num_channels,
407            samples_per_channel: self.sample_rate / 100,
408        })
409    }
410}
411
412pub fn play_remote_video_track(
413    track: &crate::RemoteVideoTrack,
414) -> impl Stream<Item = RemoteVideoFrame> + use<> {
415    #[cfg(target_os = "macos")]
416    {
417        let mut pool = None;
418        let most_recent_frame_size = (0, 0);
419        NativeVideoStream::new(track.0.rtc_track()).filter_map(move |frame| {
420            if pool == None
421                || most_recent_frame_size != (frame.buffer.width(), frame.buffer.height())
422            {
423                pool = create_buffer_pool(frame.buffer.width(), frame.buffer.height()).log_err();
424            }
425            let pool = pool.clone();
426            async move {
427                if frame.buffer.width() < 10 && frame.buffer.height() < 10 {
428                    // when the remote stops sharing, we get an 8x8 black image.
429                    // In a lil bit, the unpublish will come through and close the view,
430                    // but until then, don't flash black.
431                    return None;
432                }
433
434                video_frame_buffer_from_webrtc(pool?, frame.buffer)
435            }
436        })
437    }
438    #[cfg(not(target_os = "macos"))]
439    {
440        NativeVideoStream::new(track.0.rtc_track())
441            .filter_map(|frame| async move { video_frame_buffer_from_webrtc(frame.buffer) })
442    }
443}
444
/// Creates a CoreVideo pixel-buffer pool for `width` x `height` NV12 (full-range 4:2:0
/// bi-planar) buffers that are IOSurface/Core Animation compatible.
///
/// # Errors
///
/// Returns an error carrying the `CVReturn` code if CoreVideo refuses to create the pool.
#[cfg(target_os = "macos")]
fn create_buffer_pool(
    width: u32,
    height: u32,
) -> Result<core_video::pixel_buffer_pool::CVPixelBufferPool> {
    use core_foundation::{
        base::TCFType, dictionary::CFDictionary, number::CFNumber, string::CFString,
    };
    use core_video::{
        pixel_buffer::{self, kCVPixelFormatType_420YpCbCr8BiPlanarFullRange},
        pixel_buffer_io_surface::kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey,
        pixel_buffer_pool::CVPixelBufferPool,
    };

    // SAFETY: these are constant CFString keys exported by CoreVideo. Wrapping them under the
    // get rule retains them rather than taking ownership of the framework's reference.
    let (width_key, height_key, animation_key, format_key): (
        CFString,
        CFString,
        CFString,
        CFString,
    ) = unsafe {
        (
            CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferWidthKey),
            CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferHeightKey),
            CFString::wrap_under_get_rule(kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey),
            CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferPixelFormatTypeKey),
        )
    };

    let attributes = CFDictionary::from_CFType_pairs(&[
        (width_key, CFNumber::from(width as i32).into_CFType()),
        (height_key, CFNumber::from(height as i32).into_CFType()),
        (animation_key, CFNumber::from(1i32).into_CFType()),
        (
            format_key,
            CFNumber::from(kCVPixelFormatType_420YpCbCr8BiPlanarFullRange as i64).into_CFType(),
        ),
    ]);

    CVPixelBufferPool::new(None, Some(&attributes)).map_err(|cv_return| {
        anyhow::anyhow!("failed to create pixel buffer pool: CVReturn({cv_return})",)
    })
}
484
#[cfg(target_os = "macos")]
pub type RemoteVideoFrame = core_video::pixel_buffer::CVPixelBuffer;

/// Converts a decoded WebRTC frame into a `CVPixelBuffer`.
///
/// Native CoreVideo-backed buffers are retained and returned as-is. I420 buffers are converted
/// to NV12 in a buffer taken from `pool`. Returns `None` in these cases:
/// - the native buffer pointer is null;
/// - the frame is neither native nor I420;
/// - the pool cannot supply a buffer;
/// - locking or unlocking the buffer fails.
#[cfg(target_os = "macos")]
fn video_frame_buffer_from_webrtc(
    pool: core_video::pixel_buffer_pool::CVPixelBufferPool,
    buffer: Box<dyn VideoBuffer>,
) -> Option<RemoteVideoFrame> {
    use core_foundation::base::TCFType;
    use core_video::{pixel_buffer::CVPixelBuffer, r#return::kCVReturnSuccess};
    use livekit::webrtc::native::yuv_helper::i420_to_nv12;

    if let Some(native) = buffer.as_native() {
        let raw = native.get_cv_pixel_buffer();
        return if raw.is_null() {
            None
        } else {
            // SAFETY: `raw` is a non-null pixel buffer reference; wrapping under the get rule
            // takes our own retain instead of stealing the frame's.
            Some(unsafe { CVPixelBuffer::wrap_under_get_rule(raw as _) })
        };
    }

    let i420 = buffer.as_i420()?;
    let pixel_buffer = pool.create_pixel_buffer().log_err()?;

    // SAFETY: plane pointers, heights and strides all come from `pixel_buffer` itself and are
    // only dereferenced between a successful lock and the matching unlock.
    unsafe {
        if pixel_buffer.lock_base_address(0) != kCVReturnSuccess {
            return None;
        }

        let y_stride = pixel_buffer.get_bytes_per_row_of_plane(0);
        let uv_stride = pixel_buffer.get_bytes_per_row_of_plane(1);
        let y_plane = std::slice::from_raw_parts_mut(
            pixel_buffer.get_base_address_of_plane(0) as *mut u8,
            pixel_buffer.get_height_of_plane(0) * y_stride,
        );
        let uv_plane = std::slice::from_raw_parts_mut(
            pixel_buffer.get_base_address_of_plane(1) as *mut u8,
            pixel_buffer.get_height_of_plane(1) * uv_stride,
        );
        let (src_y, src_u, src_v) = i420.data();
        let (src_y_stride, src_u_stride, src_v_stride) = i420.strides();

        i420_to_nv12(
            src_y,
            src_y_stride,
            src_u,
            src_u_stride,
            src_v,
            src_v_stride,
            y_plane,
            y_stride as u32,
            uv_plane,
            uv_stride as u32,
            pixel_buffer.get_width() as i32,
            pixel_buffer.get_height() as i32,
        );

        if pixel_buffer.unlock_base_address(0) != kCVReturnSuccess {
            return None;
        }
    }

    Some(pixel_buffer)
}
550
551#[cfg(not(target_os = "macos"))]
552pub type RemoteVideoFrame = Arc<gpui::RenderImage>;
553
554#[cfg(not(target_os = "macos"))]
555fn video_frame_buffer_from_webrtc(buffer: Box<dyn VideoBuffer>) -> Option<RemoteVideoFrame> {
556    use gpui::RenderImage;
557    use image::{Frame, RgbaImage};
558    use livekit::webrtc::prelude::VideoFormatType;
559    use smallvec::SmallVec;
560    use std::alloc::{Layout, alloc};
561
562    let width = buffer.width();
563    let height = buffer.height();
564    let stride = width * 4;
565    let byte_len = (stride * height) as usize;
566    let argb_image = unsafe {
567        // Motivation for this unsafe code is to avoid initializing the frame data, since to_argb
568        // will write all bytes anyway.
569        let start_ptr = alloc(Layout::array::<u8>(byte_len).log_err()?);
570        if start_ptr.is_null() {
571            return None;
572        }
573        let argb_frame_slice = std::slice::from_raw_parts_mut(start_ptr, byte_len);
574        buffer.to_argb(
575            VideoFormatType::ARGB,
576            argb_frame_slice,
577            stride,
578            width as i32,
579            height as i32,
580        );
581        Vec::from_raw_parts(start_ptr, byte_len, byte_len)
582    };
583
584    // TODO: Unclear why providing argb_image to RgbaImage works properly.
585    let image = RgbaImage::from_raw(width, height, argb_image)
586        .with_context(|| "Bug: not enough bytes allocated for image.")
587        .log_err()?;
588
589    Some(Arc::new(RenderImage::new(SmallVec::from_elem(
590        Frame::new(image),
591        1,
592    ))))
593}
594
/// Wraps a captured macOS frame's pixel buffer as a WebRTC native buffer without copying.
///
/// The frame's own handle is deliberately never dropped. Its reference is handed to
/// `NativeBuffer::from_cv_pixel_buffer` (presumably taking over that retain — confirm in
/// livekit).
#[cfg(target_os = "macos")]
fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
    use livekit::webrtc;

    // Never run the pixel buffer's destructor: its reference is passed on below.
    let pixel_buffer = std::mem::ManuallyDrop::new(frame.0);
    let raw = pixel_buffer.as_concrete_TypeRef();
    // SAFETY: `raw` is a valid pixel buffer reference whose retain was leaked above.
    let native =
        unsafe { webrtc::video_frame::native::NativeBuffer::from_cv_pixel_buffer(raw as _) };
    Some(native)
}
605
606#[cfg(not(target_os = "macos"))]
607fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
608    use libwebrtc::native::yuv_helper::{abgr_to_nv12, argb_to_nv12};
609    use livekit::webrtc::prelude::NV12Buffer;
610    match frame.0 {
611        scap::frame::Frame::BGRx(frame) => {
612            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
613            let (stride_y, stride_uv) = buffer.strides();
614            let (data_y, data_uv) = buffer.data_mut();
615            argb_to_nv12(
616                &frame.data,
617                frame.width as u32 * 4,
618                data_y,
619                stride_y,
620                data_uv,
621                stride_uv,
622                frame.width,
623                frame.height,
624            );
625            Some(buffer)
626        }
627        scap::frame::Frame::RGBx(frame) => {
628            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
629            let (stride_y, stride_uv) = buffer.strides();
630            let (data_y, data_uv) = buffer.data_mut();
631            abgr_to_nv12(
632                &frame.data,
633                frame.width as u32 * 4,
634                data_y,
635                stride_y,
636                data_uv,
637                stride_uv,
638                frame.width,
639                frame.height,
640            );
641            Some(buffer)
642        }
643        scap::frame::Frame::YUVFrame(yuvframe) => {
644            let mut buffer = NV12Buffer::with_strides(
645                yuvframe.width as u32,
646                yuvframe.height as u32,
647                yuvframe.luminance_stride as u32,
648                yuvframe.chrominance_stride as u32,
649            );
650            let (luminance, chrominance) = buffer.data_mut();
651            luminance.copy_from_slice(yuvframe.luminance_bytes.as_slice());
652            chrominance.copy_from_slice(yuvframe.chrominance_bytes.as_slice());
653            Some(buffer)
654        }
655        _ => {
656            log::error!(
657                "Expected BGRx or YUV frame from scap screen capture but got some other format."
658            );
659            None
660        }
661    }
662}
663
664trait DeviceChangeListenerApi: Stream<Item = ()> + Sized {
665    fn new(input: bool) -> Result<Self>;
666}
667
668#[cfg(target_os = "macos")]
669mod macos {
670
671    use coreaudio::sys::{
672        AudioObjectAddPropertyListener, AudioObjectID, AudioObjectPropertyAddress,
673        AudioObjectRemovePropertyListener, OSStatus, kAudioHardwarePropertyDefaultInputDevice,
674        kAudioHardwarePropertyDefaultOutputDevice, kAudioObjectPropertyElementMaster,
675        kAudioObjectPropertyScopeGlobal, kAudioObjectSystemObject,
676    };
677    use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
678
679    /// Implementation from: https://github.com/zed-industries/cpal/blob/fd8bc2fd39f1f5fdee5a0690656caff9a26d9d50/src/host/coreaudio/macos/property_listener.rs#L15
680    pub struct CoreAudioDefaultDeviceChangeListener {
681        rx: UnboundedReceiver<()>,
682        callback: Box<PropertyListenerCallbackWrapper>,
683        input: bool,
684        device_id: AudioObjectID, // Store the device ID to properly remove listeners
685    }
686
687    trait _AssertSend: Send {}
688    impl _AssertSend for CoreAudioDefaultDeviceChangeListener {}
689
690    struct PropertyListenerCallbackWrapper(Box<dyn FnMut() + Send>);
691
692    unsafe extern "C" fn property_listener_handler_shim(
693        _: AudioObjectID,
694        _: u32,
695        _: *const AudioObjectPropertyAddress,
696        callback: *mut ::std::os::raw::c_void,
697    ) -> OSStatus {
698        let wrapper = callback as *mut PropertyListenerCallbackWrapper;
699        unsafe { (*wrapper).0() };
700        0
701    }
702
703    impl super::DeviceChangeListenerApi for CoreAudioDefaultDeviceChangeListener {
704        fn new(input: bool) -> anyhow::Result<Self> {
705            let (tx, rx) = futures::channel::mpsc::unbounded();
706
707            let callback = Box::new(PropertyListenerCallbackWrapper(Box::new(move || {
708                tx.unbounded_send(()).ok();
709            })));
710
711            // Get the current default device ID
712            let device_id = unsafe {
713                // Listen for default device changes
714                coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
715                    kAudioObjectSystemObject,
716                    &AudioObjectPropertyAddress {
717                        mSelector: if input {
718                            kAudioHardwarePropertyDefaultInputDevice
719                        } else {
720                            kAudioHardwarePropertyDefaultOutputDevice
721                        },
722                        mScope: kAudioObjectPropertyScopeGlobal,
723                        mElement: kAudioObjectPropertyElementMaster,
724                    },
725                    Some(property_listener_handler_shim),
726                    &*callback as *const _ as *mut _,
727                ))?;
728
729                // Also listen for changes to the device configuration
730                let device_id = if input {
731                    let mut input_device: AudioObjectID = 0;
732                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
733                    let result = coreaudio::sys::AudioObjectGetPropertyData(
734                        kAudioObjectSystemObject,
735                        &AudioObjectPropertyAddress {
736                            mSelector: kAudioHardwarePropertyDefaultInputDevice,
737                            mScope: kAudioObjectPropertyScopeGlobal,
738                            mElement: kAudioObjectPropertyElementMaster,
739                        },
740                        0,
741                        std::ptr::null(),
742                        &mut prop_size as *mut _,
743                        &mut input_device as *mut _ as *mut _,
744                    );
745                    if result != 0 {
746                        log::warn!("Failed to get default input device ID");
747                        0
748                    } else {
749                        input_device
750                    }
751                } else {
752                    let mut output_device: AudioObjectID = 0;
753                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
754                    let result = coreaudio::sys::AudioObjectGetPropertyData(
755                        kAudioObjectSystemObject,
756                        &AudioObjectPropertyAddress {
757                            mSelector: kAudioHardwarePropertyDefaultOutputDevice,
758                            mScope: kAudioObjectPropertyScopeGlobal,
759                            mElement: kAudioObjectPropertyElementMaster,
760                        },
761                        0,
762                        std::ptr::null(),
763                        &mut prop_size as *mut _,
764                        &mut output_device as *mut _ as *mut _,
765                    );
766                    if result != 0 {
767                        log::warn!("Failed to get default output device ID");
768                        0
769                    } else {
770                        output_device
771                    }
772                };
773
774                if device_id != 0 {
775                    // Listen for format changes on the device
776                    coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
777                        device_id,
778                        &AudioObjectPropertyAddress {
779                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
780                            mScope: if input {
781                                coreaudio::sys::kAudioObjectPropertyScopeInput
782                            } else {
783                                coreaudio::sys::kAudioObjectPropertyScopeOutput
784                            },
785                            mElement: kAudioObjectPropertyElementMaster,
786                        },
787                        Some(property_listener_handler_shim),
788                        &*callback as *const _ as *mut _,
789                    ))?;
790                }
791
792                device_id
793            };
794
795            Ok(Self {
796                rx,
797                callback,
798                input,
799                device_id,
800            })
801        }
802    }
803
804    impl Drop for CoreAudioDefaultDeviceChangeListener {
805        fn drop(&mut self) {
806            unsafe {
807                // Remove the system-level property listener
808                AudioObjectRemovePropertyListener(
809                    kAudioObjectSystemObject,
810                    &AudioObjectPropertyAddress {
811                        mSelector: if self.input {
812                            kAudioHardwarePropertyDefaultInputDevice
813                        } else {
814                            kAudioHardwarePropertyDefaultOutputDevice
815                        },
816                        mScope: kAudioObjectPropertyScopeGlobal,
817                        mElement: kAudioObjectPropertyElementMaster,
818                    },
819                    Some(property_listener_handler_shim),
820                    &*self.callback as *const _ as *mut _,
821                );
822
823                // Remove the device-specific property listener if we have a valid device ID
824                if self.device_id != 0 {
825                    AudioObjectRemovePropertyListener(
826                        self.device_id,
827                        &AudioObjectPropertyAddress {
828                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
829                            mScope: if self.input {
830                                coreaudio::sys::kAudioObjectPropertyScopeInput
831                            } else {
832                                coreaudio::sys::kAudioObjectPropertyScopeOutput
833                            },
834                            mElement: kAudioObjectPropertyElementMaster,
835                        },
836                        Some(property_listener_handler_shim),
837                        &*self.callback as *const _ as *mut _,
838                    );
839                }
840            }
841        }
842    }
843
844    impl futures::Stream for CoreAudioDefaultDeviceChangeListener {
845        type Item = ();
846
847        fn poll_next(
848            mut self: std::pin::Pin<&mut Self>,
849            cx: &mut std::task::Context<'_>,
850        ) -> std::task::Poll<Option<Self::Item>> {
851            self.rx.poll_next_unpin(cx)
852        }
853    }
854}
855
/// Device-change listener used on macOS: the CoreAudio-backed implementation.
#[cfg(target_os = "macos")]
type DeviceChangeListener = self::macos::CoreAudioDefaultDeviceChangeListener;
858
859#[cfg(not(target_os = "macos"))]
860mod noop_change_listener {
861    use std::task::Poll;
862
863    use super::DeviceChangeListenerApi;
864
865    pub struct NoopOutputDeviceChangelistener {}
866
867    impl DeviceChangeListenerApi for NoopOutputDeviceChangelistener {
868        fn new(_input: bool) -> anyhow::Result<Self> {
869            Ok(NoopOutputDeviceChangelistener {})
870        }
871    }
872
873    impl futures::Stream for NoopOutputDeviceChangelistener {
874        type Item = ();
875
876        fn poll_next(
877            self: std::pin::Pin<&mut Self>,
878            _cx: &mut std::task::Context<'_>,
879        ) -> Poll<Option<Self::Item>> {
880            Poll::Pending
881        }
882    }
883}
884
885#[cfg(not(target_os = "macos"))]
886type DeviceChangeListener = noop_change_listener::NoopOutputDeviceChangelistener;