playback.rs

  1use anyhow::{Context as _, Result};
  2
  3use cpal::traits::{DeviceTrait, StreamTrait as _};
  4use futures::channel::mpsc::UnboundedSender;
  5use futures::{Stream, StreamExt as _};
  6use gpui::{
  7    BackgroundExecutor, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, Task,
  8};
  9use libwebrtc::native::{apm, audio_mixer, audio_resampler};
 10use livekit::track;
 11
 12use livekit::webrtc::{
 13    audio_frame::AudioFrame,
 14    audio_source::{AudioSourceOptions, RtcAudioSource, native::NativeAudioSource},
 15    audio_stream::native::NativeAudioStream,
 16    video_frame::{VideoBuffer, VideoFrame, VideoRotation},
 17    video_source::{RtcVideoSource, VideoResolution, native::NativeVideoSource},
 18    video_stream::native::NativeVideoStream,
 19};
 20use parking_lot::Mutex;
 21use std::cell::RefCell;
 22use std::sync::Weak;
 23use std::sync::atomic::{self, AtomicI32};
 24use std::time::Duration;
 25use std::{borrow::Cow, collections::VecDeque, sync::Arc, thread};
 26use util::{ResultExt as _, maybe};
 27
/// Shared audio machinery: one WebRTC audio processing module (APM) and one WebRTC
/// mixer, used by both remote-track playback and local microphone capture.
pub(crate) struct AudioStack {
    /// Executor used to spawn the receive, transmit, capture and output tasks.
    executor: BackgroundExecutor,
    /// Audio processing module; capture audio goes through `process_stream` and
    /// playback audio through `process_reverse_stream`.
    apm: Arc<Mutex<apm::AudioProcessingModule>>,
    /// Mixer that combines every playing remote track into one output signal.
    mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
    /// Weak handle to the shared output task; each playing remote track holds a strong
    /// `Arc`, so the task is respawned once no track holds it any more.
    _output_task: RefCell<Weak<Task<()>>>,
    /// Next SSRC handed to a mixer source; starts at 1 and increments per remote track.
    next_ssrc: AtomicI32,
}
 35
 36// NOTE: We use WebRTC's mixer which only supports
 37// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
 38// for audio output devices like speakers/bluetooth, we just hard-code
 39// this; and downsample when we need to.
 40const SAMPLE_RATE: u32 = 48000;
 41const NUM_CHANNELS: u32 = 2;
 42
 43impl AudioStack {
 44    pub(crate) fn new(executor: BackgroundExecutor) -> Self {
 45        let apm = Arc::new(Mutex::new(apm::AudioProcessingModule::new(
 46            true, true, true, true,
 47        )));
 48        let mixer = Arc::new(Mutex::new(audio_mixer::AudioMixer::new()));
 49        Self {
 50            executor,
 51            apm,
 52            mixer,
 53            _output_task: RefCell::new(Weak::new()),
 54            next_ssrc: AtomicI32::new(1),
 55        }
 56    }
 57
 58    pub(crate) fn play_remote_audio_track(
 59        &self,
 60        track: &livekit::track::RemoteAudioTrack,
 61    ) -> AudioStream {
 62        let output_task = self.start_output();
 63
 64        let next_ssrc = self.next_ssrc.fetch_add(1, atomic::Ordering::Relaxed);
 65        let source = AudioMixerSource {
 66            ssrc: next_ssrc,
 67            sample_rate: SAMPLE_RATE,
 68            num_channels: NUM_CHANNELS,
 69            buffer: Arc::default(),
 70        };
 71        self.mixer.lock().add_source(source.clone());
 72
 73        let mut stream = NativeAudioStream::new(
 74            track.rtc_track(),
 75            source.sample_rate as i32,
 76            source.num_channels as i32,
 77        );
 78
 79        let receive_task = self.executor.spawn({
 80            let source = source.clone();
 81            async move {
 82                while let Some(frame) = stream.next().await {
 83                    source.receive(frame);
 84                }
 85            }
 86        });
 87
 88        let mixer = self.mixer.clone();
 89        let on_drop = util::defer(move || {
 90            mixer.lock().remove_source(source.ssrc);
 91            drop(receive_task);
 92            drop(output_task);
 93        });
 94
 95        AudioStream::Output {
 96            _drop: Box::new(on_drop),
 97        }
 98    }
 99
100    pub(crate) fn capture_local_microphone_track(
101        &self,
102    ) -> Result<(crate::LocalAudioTrack, AudioStream)> {
103        let source = NativeAudioSource::new(
104            // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
105            AudioSourceOptions::default(),
106            SAMPLE_RATE,
107            NUM_CHANNELS,
108            10,
109        );
110
111        let track = track::LocalAudioTrack::create_audio_track(
112            "microphone",
113            RtcAudioSource::Native(source.clone()),
114        );
115
116        let apm = self.apm.clone();
117
118        let (frame_tx, mut frame_rx) = futures::channel::mpsc::unbounded();
119        let transmit_task = self.executor.spawn({
120            async move {
121                while let Some(frame) = frame_rx.next().await {
122                    source.capture_frame(&frame).await.log_err();
123                }
124            }
125        });
126        let capture_task = self.executor.spawn(async move {
127            Self::capture_input(apm, frame_tx, SAMPLE_RATE, NUM_CHANNELS).await
128        });
129
130        let on_drop = util::defer(|| {
131            drop(transmit_task);
132            drop(capture_task);
133        });
134        Ok((
135            super::LocalAudioTrack(track),
136            AudioStream::Output {
137                _drop: Box::new(on_drop),
138            },
139        ))
140    }
141
142    fn start_output(&self) -> Arc<Task<()>> {
143        if let Some(task) = self._output_task.borrow().upgrade() {
144            return task;
145        }
146        let task = Arc::new(self.executor.spawn({
147            let apm = self.apm.clone();
148            let mixer = self.mixer.clone();
149            async move {
150                Self::play_output(apm, mixer, SAMPLE_RATE, NUM_CHANNELS)
151                    .await
152                    .log_err();
153            }
154        }));
155        *self._output_task.borrow_mut() = Arc::downgrade(&task);
156        task
157    }
158
159    async fn play_output(
160        apm: Arc<Mutex<apm::AudioProcessingModule>>,
161        mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
162        sample_rate: u32,
163        num_channels: u32,
164    ) -> Result<()> {
165        loop {
166            let mut device_change_listener = DeviceChangeListener::new(false)?;
167            let (output_device, output_config) = crate::default_device(false)?;
168            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
169            let mixer = mixer.clone();
170            let apm = apm.clone();
171            let mut resampler = audio_resampler::AudioResampler::default();
172            let mut buf = Vec::new();
173
174            thread::spawn(move || {
175                let output_stream = output_device.build_output_stream(
176                    &output_config.config(),
177                    {
178                        move |mut data, _info| {
179                            while data.len() > 0 {
180                                if data.len() <= buf.len() {
181                                    let rest = buf.split_off(data.len());
182                                    data.copy_from_slice(&buf);
183                                    buf = rest;
184                                    return;
185                                }
186                                if buf.len() > 0 {
187                                    let (prefix, suffix) = data.split_at_mut(buf.len());
188                                    prefix.copy_from_slice(&buf);
189                                    data = suffix;
190                                }
191
192                                let mut mixer = mixer.lock();
193                                let mixed = mixer.mix(output_config.channels() as usize);
194                                let sampled = resampler.remix_and_resample(
195                                    mixed,
196                                    sample_rate / 100,
197                                    num_channels,
198                                    sample_rate,
199                                    output_config.channels() as u32,
200                                    output_config.sample_rate().0,
201                                );
202                                buf = sampled.to_vec();
203                                apm.lock()
204                                    .process_reverse_stream(
205                                        &mut buf,
206                                        output_config.sample_rate().0 as i32,
207                                        output_config.channels() as i32,
208                                    )
209                                    .ok();
210                            }
211                        }
212                    },
213                    |error| log::error!("error playing audio track: {:?}", error),
214                    Some(Duration::from_millis(100)),
215                );
216
217                let Some(output_stream) = output_stream.log_err() else {
218                    return;
219                };
220
221                output_stream.play().log_err();
222                // Block forever to keep the output stream alive
223                end_on_drop_rx.recv().ok();
224            });
225
226            device_change_listener.next().await;
227            drop(end_on_drop_tx)
228        }
229    }
230
231    async fn capture_input(
232        apm: Arc<Mutex<apm::AudioProcessingModule>>,
233        frame_tx: UnboundedSender<AudioFrame<'static>>,
234        sample_rate: u32,
235        num_channels: u32,
236    ) -> Result<()> {
237        loop {
238            let mut device_change_listener = DeviceChangeListener::new(true)?;
239            let (device, config) = crate::default_device(true)?;
240            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
241            let apm = apm.clone();
242            let frame_tx = frame_tx.clone();
243            let mut resampler = audio_resampler::AudioResampler::default();
244
245            thread::spawn(move || {
246                maybe!({
247                    if let Some(name) = device.name().ok() {
248                        log::info!("Using microphone: {}", name)
249                    } else {
250                        log::info!("Using microphone: <unknown>");
251                    }
252
253                    let ten_ms_buffer_size =
254                        (config.channels() as u32 * config.sample_rate().0 / 100) as usize;
255                    let mut buf: Vec<i16> = Vec::with_capacity(ten_ms_buffer_size);
256
257                    let stream = device
258                        .build_input_stream_raw(
259                            &config.config(),
260                            config.sample_format(),
261                            move |data, _: &_| {
262                                let data =
263                                    crate::get_sample_data(config.sample_format(), data).log_err();
264                                let Some(data) = data else {
265                                    return;
266                                };
267                                let mut data = data.as_slice();
268
269                                while data.len() > 0 {
270                                    let remainder = (buf.capacity() - buf.len()).min(data.len());
271                                    buf.extend_from_slice(&data[..remainder]);
272                                    data = &data[remainder..];
273
274                                    if buf.capacity() == buf.len() {
275                                        let mut sampled = resampler
276                                            .remix_and_resample(
277                                                buf.as_slice(),
278                                                config.sample_rate().0 / 100,
279                                                config.channels() as u32,
280                                                config.sample_rate().0,
281                                                num_channels,
282                                                sample_rate,
283                                            )
284                                            .to_owned();
285                                        apm.lock()
286                                            .process_stream(
287                                                &mut sampled,
288                                                sample_rate as i32,
289                                                num_channels as i32,
290                                            )
291                                            .log_err();
292                                        buf.clear();
293                                        frame_tx
294                                            .unbounded_send(AudioFrame {
295                                                data: Cow::Owned(sampled),
296                                                sample_rate,
297                                                num_channels,
298                                                samples_per_channel: sample_rate / 100,
299                                            })
300                                            .ok();
301                                    }
302                                }
303                            },
304                            |err| log::error!("error capturing audio track: {:?}", err),
305                            Some(Duration::from_millis(100)),
306                        )
307                        .context("failed to build input stream")?;
308
309                    stream.play()?;
310                    // Keep the thread alive and holding onto the `stream`
311                    end_on_drop_rx.recv().ok();
312                    anyhow::Ok(Some(()))
313                })
314                .log_err();
315            });
316
317            device_change_listener.next().await;
318            drop(end_on_drop_tx)
319        }
320    }
321}
322
323use super::LocalVideoTrack;
324
/// Handle that keeps an audio pipeline alive; dropping it tears the pipeline down.
pub enum AudioStream {
    /// Keeps a single task alive. NOTE(review): not constructed in the code visible here.
    Input { _task: Task<()> },
    /// Keeps an arbitrary drop guard alive (e.g. a `util::defer` closure that removes
    /// mixer sources and drops tasks).
    Output { _drop: Box<dyn std::any::Any> },
}
329
330pub(crate) async fn capture_local_video_track(
331    capture_source: &dyn ScreenCaptureSource,
332    cx: &mut gpui::AsyncApp,
333) -> Result<(crate::LocalVideoTrack, Box<dyn ScreenCaptureStream>)> {
334    let metadata = capture_source.metadata()?;
335    let track_source = gpui_tokio::Tokio::spawn(cx, async move {
336        NativeVideoSource::new(VideoResolution {
337            width: metadata.resolution.width.0 as u32,
338            height: metadata.resolution.height.0 as u32,
339        })
340    })?
341    .await?;
342
343    let capture_stream = capture_source
344        .stream(cx.foreground_executor(), {
345            let track_source = track_source.clone();
346            Box::new(move |frame| {
347                if let Some(buffer) = video_frame_buffer_to_webrtc(frame) {
348                    track_source.capture_frame(&VideoFrame {
349                        rotation: VideoRotation::VideoRotation0,
350                        timestamp_us: 0,
351                        buffer,
352                    });
353                }
354            })
355        })
356        .await??;
357
358    Ok((
359        LocalVideoTrack(track::LocalVideoTrack::create_video_track(
360            "screen share",
361            RtcVideoSource::Native(track_source),
362        )),
363        capture_stream,
364    ))
365}
366
/// One remote track's input to the WebRTC mixer: a bounded queue of 10ms sample frames.
#[derive(Clone)]
struct AudioMixerSource {
    /// ID used to add this source to, and remove it from, the mixer.
    ssrc: i32,
    /// Sample rate (Hz) of every queued frame.
    sample_rate: u32,
    /// Channel count of every queued frame.
    num_channels: u32,
    /// Queued frames (presumably interleaved samples). Shared between clones, so the
    /// receive task and the mixer operate on the same queue.
    buffer: Arc<Mutex<VecDeque<Vec<i16>>>>,
}
374
375impl AudioMixerSource {
376    fn receive(&self, frame: AudioFrame) {
377        assert_eq!(
378            frame.data.len() as u32,
379            self.sample_rate * self.num_channels / 100
380        );
381
382        let mut buffer = self.buffer.lock();
383        buffer.push_back(frame.data.to_vec());
384        while buffer.len() > 10 {
385            buffer.pop_front();
386        }
387    }
388}
389
390impl libwebrtc::native::audio_mixer::AudioMixerSource for AudioMixerSource {
391    fn ssrc(&self) -> i32 {
392        self.ssrc
393    }
394
395    fn preferred_sample_rate(&self) -> u32 {
396        self.sample_rate
397    }
398
399    fn get_audio_frame_with_info<'a>(&self, target_sample_rate: u32) -> Option<AudioFrame<'_>> {
400        assert_eq!(self.sample_rate, target_sample_rate);
401        let buf = self.buffer.lock().pop_front()?;
402        Some(AudioFrame {
403            data: Cow::Owned(buf),
404            sample_rate: self.sample_rate,
405            num_channels: self.num_channels,
406            samples_per_channel: self.sample_rate / 100,
407        })
408    }
409}
410
411pub fn play_remote_video_track(
412    track: &crate::RemoteVideoTrack,
413) -> impl Stream<Item = RemoteVideoFrame> + use<> {
414    #[cfg(target_os = "macos")]
415    {
416        let mut pool = None;
417        let most_recent_frame_size = (0, 0);
418        NativeVideoStream::new(track.0.rtc_track()).filter_map(move |frame| {
419            if pool == None
420                || most_recent_frame_size != (frame.buffer.width(), frame.buffer.height())
421            {
422                pool = create_buffer_pool(frame.buffer.width(), frame.buffer.height()).log_err();
423            }
424            let pool = pool.clone();
425            async move {
426                if frame.buffer.width() < 10 && frame.buffer.height() < 10 {
427                    // when the remote stops sharing, we get an 8x8 black image.
428                    // In a lil bit, the unpublish will come through and close the view,
429                    // but until then, don't flash black.
430                    return None;
431                }
432
433                video_frame_buffer_from_webrtc(pool?, frame.buffer)
434            }
435        })
436    }
437    #[cfg(not(target_os = "macos"))]
438    {
439        NativeVideoStream::new(track.0.rtc_track())
440            .filter_map(|frame| async move { video_frame_buffer_from_webrtc(frame.buffer) })
441    }
442}
443
/// Creates a CoreVideo pixel-buffer pool for `width` x `height` frames.
///
/// Buffers use the bi-planar full-range 4:2:0 YCbCr format and are flagged as
/// CoreAnimation-compatible IOSurfaces.
///
/// # Errors
/// Returns an error carrying the `CVReturn` code if the pool cannot be created.
#[cfg(target_os = "macos")]
fn create_buffer_pool(
    width: u32,
    height: u32,
) -> Result<core_video::pixel_buffer_pool::CVPixelBufferPool> {
    use core_foundation::{base::TCFType, number::CFNumber, string::CFString};
    use core_video::pixel_buffer;
    use core_video::{
        pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange,
        pixel_buffer_io_surface::kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey,
        pixel_buffer_pool::{self},
    };

    // SAFETY (all four keys): these are CoreVideo-provided constant CFStrings.
    // `wrap_under_get_rule` retains them, so the wrappers never release a reference they
    // do not own.
    let width_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferWidthKey) };
    let height_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferHeightKey) };
    let animation_key: CFString = unsafe {
        CFString::wrap_under_get_rule(kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey)
    };
    let format_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferPixelFormatTypeKey) };

    let yes: CFNumber = 1.into();
    let width: CFNumber = (width as i32).into();
    let height: CFNumber = (height as i32).into();
    let format: CFNumber = (kCVPixelFormatType_420YpCbCr8BiPlanarFullRange as i64).into();

    // Attributes applied to every pixel buffer the pool vends.
    let buffer_attributes = core_foundation::dictionary::CFDictionary::from_CFType_pairs(&[
        (width_key, width.into_CFType()),
        (height_key, height.into_CFType()),
        (animation_key, yes.into_CFType()),
        (format_key, format.into_CFType()),
    ]);

    pixel_buffer_pool::CVPixelBufferPool::new(None, Some(&buffer_attributes)).map_err(|cv_return| {
        anyhow::anyhow!("failed to create pixel buffer pool: CVReturn({cv_return})",)
    })
}
483
/// A displayable remote video frame on macOS: a CoreVideo pixel buffer.
#[cfg(target_os = "macos")]
pub type RemoteVideoFrame = core_video::pixel_buffer::CVPixelBuffer;
486
/// Converts a decoded WebRTC frame buffer into a CoreVideo pixel buffer for display.
///
/// If the buffer is already backed by a native `CVPixelBuffer`, that buffer is returned
/// directly. Otherwise the frame is read as I420 and converted into a bi-planar buffer
/// taken from `pool`.
///
/// Returns `None` if any of the following hold:
/// - the native pixel buffer pointer is null,
/// - the frame cannot be viewed as I420,
/// - the pool cannot vend a buffer, or
/// - locking/unlocking the buffer's base address fails.
#[cfg(target_os = "macos")]
fn video_frame_buffer_from_webrtc(
    pool: core_video::pixel_buffer_pool::CVPixelBufferPool,
    buffer: Box<dyn VideoBuffer>,
) -> Option<RemoteVideoFrame> {
    use core_foundation::base::TCFType;
    use core_video::{pixel_buffer::CVPixelBuffer, r#return::kCVReturnSuccess};
    use livekit::webrtc::native::yuv_helper::i420_to_nv12;

    if let Some(native) = buffer.as_native() {
        let pixel_buffer = native.get_cv_pixel_buffer();
        if pixel_buffer.is_null() {
            return None;
        }
        // SAFETY: non-null (checked above). `wrap_under_get_rule` retains it, so the
        // returned wrapper owns its own reference independent of `buffer`.
        return unsafe { Some(CVPixelBuffer::wrap_under_get_rule(pixel_buffer as _)) };
    }

    let i420_buffer = buffer.as_i420()?;
    let pixel_buffer = pool.create_pixel_buffer().log_err()?;

    let image_buffer = unsafe {
        // The plane base addresses are only valid while the base address is locked.
        if pixel_buffer.lock_base_address(0) != kCVReturnSuccess {
            return None;
        }

        let dst_y = pixel_buffer.get_base_address_of_plane(0);
        let dst_y_stride = pixel_buffer.get_bytes_per_row_of_plane(0);
        let dst_y_len = pixel_buffer.get_height_of_plane(0) * dst_y_stride;
        let dst_uv = pixel_buffer.get_base_address_of_plane(1);
        let dst_uv_stride = pixel_buffer.get_bytes_per_row_of_plane(1);
        let dst_uv_len = pixel_buffer.get_height_of_plane(1) * dst_uv_stride;
        let width = pixel_buffer.get_width();
        let height = pixel_buffer.get_height();
        // SAFETY: each plane spans `rows * bytes_per_row` bytes from its base address,
        // and the buffer stays locked until after the conversion below.
        let dst_y_buffer = std::slice::from_raw_parts_mut(dst_y as *mut u8, dst_y_len);
        let dst_uv_buffer = std::slice::from_raw_parts_mut(dst_uv as *mut u8, dst_uv_len);

        // NOTE(review): the conversion uses the pixel buffer's dimensions, which assumes
        // `pool` was created for this frame's size.
        let (stride_y, stride_u, stride_v) = i420_buffer.strides();
        let (src_y, src_u, src_v) = i420_buffer.data();
        i420_to_nv12(
            src_y,
            stride_y,
            src_u,
            stride_u,
            src_v,
            stride_v,
            dst_y_buffer,
            dst_y_stride as u32,
            dst_uv_buffer,
            dst_uv_stride as u32,
            width as i32,
            height as i32,
        );

        if pixel_buffer.unlock_base_address(0) != kCVReturnSuccess {
            return None;
        }

        pixel_buffer
    };

    Some(image_buffer)
}
549
/// A displayable remote video frame on non-macOS platforms: a shared GPUI render image.
#[cfg(not(target_os = "macos"))]
pub type RemoteVideoFrame = Arc<gpui::RenderImage>;
552
553#[cfg(not(target_os = "macos"))]
554fn video_frame_buffer_from_webrtc(buffer: Box<dyn VideoBuffer>) -> Option<RemoteVideoFrame> {
555    use gpui::RenderImage;
556    use image::{Frame, RgbaImage};
557    use livekit::webrtc::prelude::VideoFormatType;
558    use smallvec::SmallVec;
559    use std::alloc::{Layout, alloc};
560
561    let width = buffer.width();
562    let height = buffer.height();
563    let stride = width * 4;
564    let byte_len = (stride * height) as usize;
565    let argb_image = unsafe {
566        // Motivation for this unsafe code is to avoid initializing the frame data, since to_argb
567        // will write all bytes anyway.
568        let start_ptr = alloc(Layout::array::<u8>(byte_len).log_err()?);
569        if start_ptr.is_null() {
570            return None;
571        }
572        let argb_frame_slice = std::slice::from_raw_parts_mut(start_ptr, byte_len);
573        buffer.to_argb(
574            VideoFormatType::ARGB,
575            argb_frame_slice,
576            stride,
577            width as i32,
578            height as i32,
579        );
580        Vec::from_raw_parts(start_ptr, byte_len, byte_len)
581    };
582
583    // TODO: Unclear why providing argb_image to RgbaImage works properly.
584    let image = RgbaImage::from_raw(width, height, argb_image)
585        .with_context(|| "Bug: not enough bytes allocated for image.")
586        .log_err()?;
587
588    Some(Arc::new(RenderImage::new(SmallVec::from_elem(
589        Frame::new(image),
590        1,
591    ))))
592}
593
/// Wraps a captured macOS screen frame (a CoreVideo pixel buffer) as a WebRTC native
/// video buffer without copying pixel data.
#[cfg(target_os = "macos")]
fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
    use livekit::webrtc;

    let pixel_buffer = frame.0.as_concrete_TypeRef();
    // Deliberately skip releasing `frame.0`: the reference it owns is handed over to the
    // native buffer. NOTE(review): this assumes `from_cv_pixel_buffer` adopts (and later
    // releases) that reference — if it retains instead, one reference leaks per frame.
    std::mem::forget(frame.0);
    // SAFETY: `pixel_buffer` is a valid CVPixelBufferRef kept alive by the forgotten
    // reference above.
    unsafe {
        Some(webrtc::video_frame::native::NativeBuffer::from_cv_pixel_buffer(pixel_buffer as _))
    }
}
604
605#[cfg(not(target_os = "macos"))]
606fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
607    use libwebrtc::native::yuv_helper::{abgr_to_nv12, argb_to_nv12};
608    use livekit::webrtc::prelude::NV12Buffer;
609    match frame.0 {
610        scap::frame::Frame::BGRx(frame) => {
611            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
612            let (stride_y, stride_uv) = buffer.strides();
613            let (data_y, data_uv) = buffer.data_mut();
614            argb_to_nv12(
615                &frame.data,
616                frame.width as u32 * 4,
617                data_y,
618                stride_y,
619                data_uv,
620                stride_uv,
621                frame.width,
622                frame.height,
623            );
624            Some(buffer)
625        }
626        scap::frame::Frame::RGBx(frame) => {
627            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
628            let (stride_y, stride_uv) = buffer.strides();
629            let (data_y, data_uv) = buffer.data_mut();
630            abgr_to_nv12(
631                &frame.data,
632                frame.width as u32 * 4,
633                data_y,
634                stride_y,
635                data_uv,
636                stride_uv,
637                frame.width,
638                frame.height,
639            );
640            Some(buffer)
641        }
642        scap::frame::Frame::YUVFrame(yuvframe) => {
643            let mut buffer = NV12Buffer::with_strides(
644                yuvframe.width as u32,
645                yuvframe.height as u32,
646                yuvframe.luminance_stride as u32,
647                yuvframe.chrominance_stride as u32,
648            );
649            let (luminance, chrominance) = buffer.data_mut();
650            luminance.copy_from_slice(yuvframe.luminance_bytes.as_slice());
651            chrominance.copy_from_slice(yuvframe.chrominance_bytes.as_slice());
652            Some(buffer)
653        }
654        _ => {
655            log::error!(
656                "Expected BGRx or YUV frame from scap screen capture but got some other format."
657            );
658            None
659        }
660    }
661}
662
/// Platform hook for observing changes to the system's default audio device.
///
/// Implementations are streams that yield `()` each time a relevant change is observed.
/// The audio loops wait on them and then reopen the default device.
trait DeviceChangeListenerApi: Stream<Item = ()> + Sized {
    /// Starts listening. `input == true` watches the default input (microphone) device;
    /// `input == false` watches the default output device.
    fn new(input: bool) -> Result<Self>;
}
666
667#[cfg(target_os = "macos")]
668mod macos {
669
670    use coreaudio::sys::{
671        AudioObjectAddPropertyListener, AudioObjectID, AudioObjectPropertyAddress,
672        AudioObjectRemovePropertyListener, OSStatus, kAudioHardwarePropertyDefaultInputDevice,
673        kAudioHardwarePropertyDefaultOutputDevice, kAudioObjectPropertyElementMaster,
674        kAudioObjectPropertyScopeGlobal, kAudioObjectSystemObject,
675    };
676    use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
677
    /// Implementation from: https://github.com/zed-industries/cpal/blob/fd8bc2fd39f1f5fdee5a0690656caff9a26d9d50/src/host/coreaudio/macos/property_listener.rs#L15
    ///
    /// Stream of notifications for changes to the default CoreAudio input or output device.
    pub struct CoreAudioDefaultDeviceChangeListener {
        /// Receives one `()` per CoreAudio property notification.
        rx: UnboundedReceiver<()>,
        /// Boxed so its address stays stable; that address is registered with CoreAudio
        /// as the listener's client data.
        callback: Box<PropertyListenerCallbackWrapper>,
        /// Whether the default input (`true`) or output (`false`) device is watched.
        input: bool,
        device_id: AudioObjectID, // Store the device ID to properly remove listeners
    }
685
    // Compile-time assertion that the listener is `Send`: the impl below fails to compile
    // if it ever stops being `Send`.
    trait _AssertSend: Send {}
    impl _AssertSend for CoreAudioDefaultDeviceChangeListener {}
688
    /// Thin-pointer wrapper around the Rust closure invoked by `property_listener_handler_shim`.
    struct PropertyListenerCallbackWrapper(Box<dyn FnMut() + Send>);
690
    /// C-ABI trampoline registered with CoreAudio. It forwards every property
    /// notification to the Rust closure behind `callback` and returns 0 (success).
    ///
    /// # Safety
    /// `callback` must point to a live `PropertyListenerCallbackWrapper` for as long as the
    /// listener stays registered.
    unsafe extern "C" fn property_listener_handler_shim(
        _: AudioObjectID,
        _: u32,
        _: *const AudioObjectPropertyAddress,
        callback: *mut ::std::os::raw::c_void,
    ) -> OSStatus {
        let wrapper = callback as *mut PropertyListenerCallbackWrapper;
        // SAFETY: guaranteed by this function's contract above.
        unsafe { (*wrapper).0() };
        0
    }
701
702    impl super::DeviceChangeListenerApi for CoreAudioDefaultDeviceChangeListener {
703        fn new(input: bool) -> anyhow::Result<Self> {
704            let (tx, rx) = futures::channel::mpsc::unbounded();
705
706            let callback = Box::new(PropertyListenerCallbackWrapper(Box::new(move || {
707                tx.unbounded_send(()).ok();
708            })));
709
710            // Get the current default device ID
711            let device_id = unsafe {
712                // Listen for default device changes
713                coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
714                    kAudioObjectSystemObject,
715                    &AudioObjectPropertyAddress {
716                        mSelector: if input {
717                            kAudioHardwarePropertyDefaultInputDevice
718                        } else {
719                            kAudioHardwarePropertyDefaultOutputDevice
720                        },
721                        mScope: kAudioObjectPropertyScopeGlobal,
722                        mElement: kAudioObjectPropertyElementMaster,
723                    },
724                    Some(property_listener_handler_shim),
725                    &*callback as *const _ as *mut _,
726                ))?;
727
728                // Also listen for changes to the device configuration
729                let device_id = if input {
730                    let mut input_device: AudioObjectID = 0;
731                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
732                    let result = coreaudio::sys::AudioObjectGetPropertyData(
733                        kAudioObjectSystemObject,
734                        &AudioObjectPropertyAddress {
735                            mSelector: kAudioHardwarePropertyDefaultInputDevice,
736                            mScope: kAudioObjectPropertyScopeGlobal,
737                            mElement: kAudioObjectPropertyElementMaster,
738                        },
739                        0,
740                        std::ptr::null(),
741                        &mut prop_size as *mut _,
742                        &mut input_device as *mut _ as *mut _,
743                    );
744                    if result != 0 {
745                        log::warn!("Failed to get default input device ID");
746                        0
747                    } else {
748                        input_device
749                    }
750                } else {
751                    let mut output_device: AudioObjectID = 0;
752                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
753                    let result = coreaudio::sys::AudioObjectGetPropertyData(
754                        kAudioObjectSystemObject,
755                        &AudioObjectPropertyAddress {
756                            mSelector: kAudioHardwarePropertyDefaultOutputDevice,
757                            mScope: kAudioObjectPropertyScopeGlobal,
758                            mElement: kAudioObjectPropertyElementMaster,
759                        },
760                        0,
761                        std::ptr::null(),
762                        &mut prop_size as *mut _,
763                        &mut output_device as *mut _ as *mut _,
764                    );
765                    if result != 0 {
766                        log::warn!("Failed to get default output device ID");
767                        0
768                    } else {
769                        output_device
770                    }
771                };
772
773                if device_id != 0 {
774                    // Listen for format changes on the device
775                    coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
776                        device_id,
777                        &AudioObjectPropertyAddress {
778                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
779                            mScope: if input {
780                                coreaudio::sys::kAudioObjectPropertyScopeInput
781                            } else {
782                                coreaudio::sys::kAudioObjectPropertyScopeOutput
783                            },
784                            mElement: kAudioObjectPropertyElementMaster,
785                        },
786                        Some(property_listener_handler_shim),
787                        &*callback as *const _ as *mut _,
788                    ))?;
789                }
790
791                device_id
792            };
793
794            Ok(Self {
795                rx,
796                callback,
797                input,
798                device_id,
799            })
800        }
801    }
802
803    impl Drop for CoreAudioDefaultDeviceChangeListener {
804        fn drop(&mut self) {
805            unsafe {
806                // Remove the system-level property listener
807                AudioObjectRemovePropertyListener(
808                    kAudioObjectSystemObject,
809                    &AudioObjectPropertyAddress {
810                        mSelector: if self.input {
811                            kAudioHardwarePropertyDefaultInputDevice
812                        } else {
813                            kAudioHardwarePropertyDefaultOutputDevice
814                        },
815                        mScope: kAudioObjectPropertyScopeGlobal,
816                        mElement: kAudioObjectPropertyElementMaster,
817                    },
818                    Some(property_listener_handler_shim),
819                    &*self.callback as *const _ as *mut _,
820                );
821
822                // Remove the device-specific property listener if we have a valid device ID
823                if self.device_id != 0 {
824                    AudioObjectRemovePropertyListener(
825                        self.device_id,
826                        &AudioObjectPropertyAddress {
827                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
828                            mScope: if self.input {
829                                coreaudio::sys::kAudioObjectPropertyScopeInput
830                            } else {
831                                coreaudio::sys::kAudioObjectPropertyScopeOutput
832                            },
833                            mElement: kAudioObjectPropertyElementMaster,
834                        },
835                        Some(property_listener_handler_shim),
836                        &*self.callback as *const _ as *mut _,
837                    );
838                }
839            }
840        }
841    }
842
843    impl futures::Stream for CoreAudioDefaultDeviceChangeListener {
844        type Item = ();
845
846        fn poll_next(
847            mut self: std::pin::Pin<&mut Self>,
848            cx: &mut std::task::Context<'_>,
849        ) -> std::task::Poll<Option<Self::Item>> {
850            self.rx.poll_next_unpin(cx)
851        }
852    }
853}
854
/// Device-change listener used on macOS: watches the default input/output device
/// through CoreAudio property listeners.
#[cfg(target_os = "macos")]
type DeviceChangeListener = macos::CoreAudioDefaultDeviceChangeListener;
857
858#[cfg(not(target_os = "macos"))]
859mod noop_change_listener {
860    use std::task::Poll;
861
862    use super::DeviceChangeListenerApi;
863
864    pub struct NoopOutputDeviceChangelistener {}
865
866    impl DeviceChangeListenerApi for NoopOutputDeviceChangelistener {
867        fn new(_input: bool) -> anyhow::Result<Self> {
868            Ok(NoopOutputDeviceChangelistener {})
869        }
870    }
871
872    impl futures::Stream for NoopOutputDeviceChangelistener {
873        type Item = ();
874
875        fn poll_next(
876            self: std::pin::Pin<&mut Self>,
877            _cx: &mut std::task::Context<'_>,
878        ) -> Poll<Option<Self::Item>> {
879            Poll::Pending
880        }
881    }
882}
883
/// Device-change listener used on non-macOS targets: a no-op whose stream never yields.
#[cfg(not(target_os = "macos"))]
type DeviceChangeListener = noop_change_listener::NoopOutputDeviceChangelistener;