// playback.rs

  1use anyhow::{Context as _, Result};
  2
  3use cpal::traits::{DeviceTrait, HostTrait, StreamTrait as _};
  4use futures::channel::mpsc::UnboundedSender;
  5use futures::{Stream, StreamExt as _};
  6use gpui::{
  7    BackgroundExecutor, ScreenCaptureFrame, ScreenCaptureSource, ScreenCaptureStream, Task,
  8};
  9use libwebrtc::native::{apm, audio_mixer, audio_resampler};
 10use livekit::track;
 11
 12use livekit::webrtc::{
 13    audio_frame::AudioFrame,
 14    audio_source::{AudioSourceOptions, RtcAudioSource, native::NativeAudioSource},
 15    audio_stream::native::NativeAudioStream,
 16    video_frame::{VideoBuffer, VideoFrame, VideoRotation},
 17    video_source::{RtcVideoSource, VideoResolution, native::NativeVideoSource},
 18    video_stream::native::NativeVideoStream,
 19};
 20use parking_lot::Mutex;
 21use std::cell::RefCell;
 22use std::sync::Weak;
 23use std::sync::atomic::{self, AtomicI32};
 24use std::time::Duration;
 25use std::{borrow::Cow, collections::VecDeque, sync::Arc, thread};
 26use util::{ResultExt as _, maybe};
 27
/// Owns the audio capture/playback machinery shared by all livekit tracks:
/// one audio-processing module, one mixer, and at most one output task.
pub(crate) struct AudioStack {
    // Executor used to spawn the receive/transmit/capture/output tasks.
    executor: BackgroundExecutor,
    // WebRTC audio processing module shared by the capture path
    // (`process_stream`) and the playback path (`process_reverse_stream`).
    apm: Arc<Mutex<apm::AudioProcessingModule>>,
    // Mixes all registered remote audio sources into one output feed.
    mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
    // Weak handle to the single shared output task; `start_output` upgrades it
    // (or respawns) so playback stops once every returned stream is dropped.
    _output_task: RefCell<Weak<Task<()>>>,
    // Monotonically increasing id handed to mixer sources.
    next_ssrc: AtomicI32,
}
 35
 36// NOTE: We use WebRTC's mixer which only supports
 37// 16kHz, 32kHz and 48kHz. As 48 is the most common "next step up"
 38// for audio output devices like speakers/bluetooth, we just hard-code
 39// this; and downsample when we need to.
 40const SAMPLE_RATE: u32 = 48000;
 41const NUM_CHANNELS: u32 = 2;
 42
impl AudioStack {
    /// Creates a new audio stack; background work is spawned on `executor`.
    pub(crate) fn new(executor: BackgroundExecutor) -> Self {
        // NOTE(review): the four flags presumably enable echo cancellation,
        // noise suppression, gain control and a high-pass filter — confirm
        // against `apm::AudioProcessingModule::new`'s parameter order.
        let apm = Arc::new(Mutex::new(apm::AudioProcessingModule::new(
            true, true, true, true,
        )));
        let mixer = Arc::new(Mutex::new(audio_mixer::AudioMixer::new()));
        Self {
            executor,
            apm,
            mixer,
            _output_task: RefCell::new(Weak::new()),
            next_ssrc: AtomicI32::new(1),
        }
    }

    /// Plays a remote audio track through the shared mixer and output device.
    ///
    /// Registers a new mixer source for the track and spawns a task that
    /// forwards incoming WebRTC frames into it. Dropping the returned
    /// `AudioStream` removes the source and releases the tasks.
    pub(crate) fn play_remote_audio_track(
        &self,
        track: &livekit::track::RemoteAudioTrack,
    ) -> AudioStream {
        // Lazily start (or reuse) the single shared output task.
        let output_task = self.start_output();

        let next_ssrc = self.next_ssrc.fetch_add(1, atomic::Ordering::Relaxed);
        let source = AudioMixerSource {
            ssrc: next_ssrc,
            sample_rate: SAMPLE_RATE,
            num_channels: NUM_CHANNELS,
            buffer: Arc::default(),
        };
        self.mixer.lock().add_source(source.clone());

        let mut stream = NativeAudioStream::new(
            track.rtc_track(),
            source.sample_rate as i32,
            source.num_channels as i32,
        );

        // Forward decoded frames from the remote track into the mixer source.
        let receive_task = self.executor.spawn({
            let source = source.clone();
            async move {
                while let Some(frame) = stream.next().await {
                    source.receive(frame);
                }
            }
        });

        // On drop: unregister the source and cancel both tasks. The shared
        // output task ends once its last `Arc` is dropped.
        let mixer = self.mixer.clone();
        let on_drop = util::defer(move || {
            mixer.lock().remove_source(source.ssrc);
            drop(receive_task);
            drop(output_task);
        });

        AudioStream::Output {
            _drop: Box::new(on_drop),
        }
    }

    /// Captures microphone audio into a new local track.
    ///
    /// Returns the publishable track plus a stream handle whose drop tears
    /// down the capture and transmit tasks.
    pub(crate) fn capture_local_microphone_track(
        &self,
    ) -> Result<(crate::LocalAudioTrack, AudioStream)> {
        let source = NativeAudioSource::new(
            // n.b. this struct's options are always ignored, noise cancellation is provided by apm.
            AudioSourceOptions::default(),
            SAMPLE_RATE,
            NUM_CHANNELS,
            10,
        );

        let track = track::LocalAudioTrack::create_audio_track(
            "microphone",
            RtcAudioSource::Native(source.clone()),
        );

        let apm = self.apm.clone();

        // `capture_input` produces 10 ms frames on `frame_tx`; a separate task
        // pushes them into the WebRTC source.
        let (frame_tx, mut frame_rx) = futures::channel::mpsc::unbounded();
        let transmit_task = self.executor.spawn({
            let source = source.clone();
            async move {
                while let Some(frame) = frame_rx.next().await {
                    source.capture_frame(&frame).await.log_err();
                }
            }
        });
        let capture_task = self.executor.spawn(async move {
            Self::capture_input(apm, frame_tx, SAMPLE_RATE, NUM_CHANNELS).await
        });

        let on_drop = util::defer(|| {
            drop(transmit_task);
            drop(capture_task);
        });
        return Ok((
            super::LocalAudioTrack(track),
            AudioStream::Output {
                _drop: Box::new(on_drop),
            },
        ));
    }

    /// Returns the shared output task, spawning it if no strong handle exists.
    fn start_output(&self) -> Arc<Task<()>> {
        if let Some(task) = self._output_task.borrow().upgrade() {
            return task;
        }
        let task = Arc::new(self.executor.spawn({
            let apm = self.apm.clone();
            let mixer = self.mixer.clone();
            async move {
                Self::play_output(apm, mixer, SAMPLE_RATE, NUM_CHANNELS)
                    .await
                    .log_err();
            }
        }));
        // Only a weak reference is kept here, so the task is dropped (and
        // playback stops) when the last caller releases its `Arc`.
        *self._output_task.borrow_mut() = Arc::downgrade(&task);
        task
    }

    /// Drives the speaker: mixes registered sources, resamples to the device
    /// format, and feeds the result to cpal. Restarts the stream whenever the
    /// default output device changes.
    async fn play_output(
        apm: Arc<Mutex<apm::AudioProcessingModule>>,
        mixer: Arc<Mutex<audio_mixer::AudioMixer>>,
        sample_rate: u32,
        num_channels: u32,
    ) -> Result<()> {
        loop {
            // `false` selects the output side for both the change listener and
            // the device lookup.
            let mut device_change_listener = DeviceChangeListener::new(false)?;
            let (output_device, output_config) = default_device(false)?;
            // Dropping the sender unblocks the playback thread's `recv`,
            // ending the stream so the loop can rebuild on the new device.
            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
            let mixer = mixer.clone();
            let apm = apm.clone();
            let mut resampler = audio_resampler::AudioResampler::default();
            // Leftover resampled samples carried across cpal callbacks.
            let mut buf = Vec::new();

            thread::spawn(move || {
                let output_stream = output_device.build_output_stream(
                    &output_config.config(),
                    {
                        move |mut data, _info| {
                            while data.len() > 0 {
                                // If the carry-over buffer can satisfy the
                                // whole request, copy and keep the rest.
                                if data.len() <= buf.len() {
                                    let rest = buf.split_off(data.len());
                                    data.copy_from_slice(&buf);
                                    buf = rest;
                                    return;
                                }
                                // Otherwise drain the carry-over first, then
                                // mix/resample another chunk below.
                                if buf.len() > 0 {
                                    let (prefix, suffix) = data.split_at_mut(buf.len());
                                    prefix.copy_from_slice(&buf);
                                    data = suffix;
                                }

                                let mut mixer = mixer.lock();
                                let mixed = mixer.mix(output_config.channels() as usize);
                                // Convert from the mixer's rate/layout to the
                                // device's; `sample_rate / 100` is the 10 ms
                                // frame size per channel.
                                let sampled = resampler.remix_and_resample(
                                    mixed,
                                    sample_rate / 100,
                                    num_channels,
                                    sample_rate,
                                    output_config.channels() as u32,
                                    output_config.sample_rate().0,
                                );
                                buf = sampled.to_vec();
                                // Feed the rendered audio back into the APM
                                // (the "reverse stream", used as the echo-
                                // cancellation reference).
                                apm.lock()
                                    .process_reverse_stream(
                                        &mut buf,
                                        output_config.sample_rate().0 as i32,
                                        output_config.channels() as i32,
                                    )
                                    .ok();
                            }
                        }
                    },
                    |error| log::error!("error playing audio track: {:?}", error),
                    Some(Duration::from_millis(100)),
                );

                let Some(output_stream) = output_stream.log_err() else {
                    return;
                };

                output_stream.play().log_err();
                // Block forever to keep the output stream alive
                end_on_drop_rx.recv().ok();
            });

            device_change_listener.next().await;
            drop(end_on_drop_tx)
        }
    }

    /// Drives the microphone: reads i16 samples from cpal, accumulates 10 ms
    /// device-rate chunks, resamples to `sample_rate`/`num_channels`, runs the
    /// APM, and ships `AudioFrame`s on `frame_tx`. Restarts on input-device
    /// changes, like `play_output`.
    async fn capture_input(
        apm: Arc<Mutex<apm::AudioProcessingModule>>,
        frame_tx: UnboundedSender<AudioFrame<'static>>,
        sample_rate: u32,
        num_channels: u32,
    ) -> Result<()> {
        loop {
            let mut device_change_listener = DeviceChangeListener::new(true)?;
            let (device, config) = default_device(true)?;
            let (end_on_drop_tx, end_on_drop_rx) = std::sync::mpsc::channel::<()>();
            let apm = apm.clone();
            let frame_tx = frame_tx.clone();
            let mut resampler = audio_resampler::AudioResampler::default();

            thread::spawn(move || {
                maybe!({
                    if let Some(name) = device.name().ok() {
                        log::info!("Using microphone: {}", name)
                    } else {
                        log::info!("Using microphone: <unknown>");
                    }

                    // 10 ms worth of interleaved samples at the device rate.
                    let ten_ms_buffer_size =
                        (config.channels() as u32 * config.sample_rate().0 / 100) as usize;
                    let mut buf: Vec<i16> = Vec::with_capacity(ten_ms_buffer_size);

                    let stream = device
                        .build_input_stream_raw(
                            &config.config(),
                            cpal::SampleFormat::I16,
                            move |data, _: &_| {
                                let mut data = data.as_slice::<i16>().unwrap();
                                while data.len() > 0 {
                                    // Fill `buf` up to exactly 10 ms before
                                    // processing; `capacity` never changes, so
                                    // full == `ten_ms_buffer_size` reached.
                                    let remainder = (buf.capacity() - buf.len()).min(data.len());
                                    buf.extend_from_slice(&data[..remainder]);
                                    data = &data[remainder..];

                                    if buf.capacity() == buf.len() {
                                        let mut sampled = resampler
                                            .remix_and_resample(
                                                buf.as_slice(),
                                                config.sample_rate().0 / 100,
                                                config.channels() as u32,
                                                config.sample_rate().0,
                                                num_channels,
                                                sample_rate,
                                            )
                                            .to_owned();
                                        // Forward (capture) path of the APM.
                                        apm.lock()
                                            .process_stream(
                                                &mut sampled,
                                                sample_rate as i32,
                                                num_channels as i32,
                                            )
                                            .log_err();
                                        buf.clear();
                                        // Best-effort send: the receiver may
                                        // already be gone during teardown.
                                        frame_tx
                                            .unbounded_send(AudioFrame {
                                                data: Cow::Owned(sampled),
                                                sample_rate,
                                                num_channels,
                                                samples_per_channel: sample_rate / 100,
                                            })
                                            .ok();
                                    }
                                }
                            },
                            |err| log::error!("error capturing audio track: {:?}", err),
                            Some(Duration::from_millis(100)),
                        )
                        .context("failed to build input stream")?;

                    stream.play()?;
                    // Keep the thread alive and holding onto the `stream`
                    end_on_drop_rx.recv().ok();
                    anyhow::Ok(Some(()))
                })
                .log_err();
            });

            device_change_listener.next().await;
            drop(end_on_drop_tx)
        }
    }
}
317
318use super::LocalVideoTrack;
319
/// Handle returned by the audio capture/playback methods. Dropping it tears
/// down the associated background tasks and device streams.
pub enum AudioStream {
    // Keeps an input-driving task alive for as long as the stream exists.
    Input { _task: Task<()> },
    // Holds an arbitrary cleanup guard (e.g. a `util::defer`) run on drop.
    Output { _drop: Box<dyn std::any::Any> },
}
324
/// Starts a screen share: creates a native video source sized to the capture
/// source's resolution, streams captured frames into it, and returns the
/// publishable local video track along with the capture-stream handle.
pub(crate) async fn capture_local_video_track(
    capture_source: &dyn ScreenCaptureSource,
    cx: &mut gpui::AsyncApp,
) -> Result<(crate::LocalVideoTrack, Box<dyn ScreenCaptureStream>)> {
    let metadata = capture_source.metadata()?;
    // NOTE(review): the source is constructed on the tokio runtime —
    // presumably `NativeVideoSource::new` requires a tokio context; confirm.
    let track_source = gpui_tokio::Tokio::spawn(cx, async move {
        NativeVideoSource::new(VideoResolution {
            width: metadata.resolution.width.0 as u32,
            height: metadata.resolution.height.0 as u32,
        })
    })?
    .await?;

    let capture_stream = capture_source
        .stream(cx.foreground_executor(), {
            let track_source = track_source.clone();
            // Convert each captured frame to a WebRTC buffer and push it into
            // the source; frames that fail conversion are skipped.
            Box::new(move |frame| {
                if let Some(buffer) = video_frame_buffer_to_webrtc(frame) {
                    track_source.capture_frame(&VideoFrame {
                        rotation: VideoRotation::VideoRotation0,
                        timestamp_us: 0,
                        buffer,
                    });
                }
            })
        })
        .await??;

    Ok((
        LocalVideoTrack(track::LocalVideoTrack::create_video_track(
            "screen share",
            RtcVideoSource::Native(track_source),
        )),
        capture_stream,
    ))
}
361
362fn default_device(input: bool) -> Result<(cpal::Device, cpal::SupportedStreamConfig)> {
363    let device;
364    let config;
365    if input {
366        device = cpal::default_host()
367            .default_input_device()
368            .context("no audio input device available")?;
369        config = device
370            .default_input_config()
371            .context("failed to get default input config")?;
372    } else {
373        device = cpal::default_host()
374            .default_output_device()
375            .context("no audio output device available")?;
376        config = device
377            .default_output_config()
378            .context("failed to get default output config")?;
379    }
380    Ok((device, config))
381}
382
/// A single remote audio feed registered with the WebRTC mixer.
#[derive(Clone)]
struct AudioMixerSource {
    // Id used to register/remove this source with the mixer.
    ssrc: i32,
    // Sample rate of the frames this source receives and serves (Hz).
    sample_rate: u32,
    // Interleaved channel count of the buffered frames.
    num_channels: u32,
    // FIFO of 10 ms frames awaiting mixing; bounded to 10 entries in
    // `receive`, so it holds at most ~100 ms of audio.
    buffer: Arc<Mutex<VecDeque<Vec<i16>>>>,
}
390
391impl AudioMixerSource {
392    fn receive(&self, frame: AudioFrame) {
393        assert_eq!(
394            frame.data.len() as u32,
395            self.sample_rate * self.num_channels / 100
396        );
397
398        let mut buffer = self.buffer.lock();
399        buffer.push_back(frame.data.to_vec());
400        while buffer.len() > 10 {
401            buffer.pop_front();
402        }
403    }
404}
405
406impl libwebrtc::native::audio_mixer::AudioMixerSource for AudioMixerSource {
407    fn ssrc(&self) -> i32 {
408        self.ssrc
409    }
410
411    fn preferred_sample_rate(&self) -> u32 {
412        self.sample_rate
413    }
414
415    fn get_audio_frame_with_info<'a>(&self, target_sample_rate: u32) -> Option<AudioFrame<'_>> {
416        assert_eq!(self.sample_rate, target_sample_rate);
417        let buf = self.buffer.lock().pop_front()?;
418        Some(AudioFrame {
419            data: Cow::Owned(buf),
420            sample_rate: self.sample_rate,
421            num_channels: self.num_channels,
422            samples_per_channel: self.sample_rate / 100,
423        })
424    }
425}
426
427pub fn play_remote_video_track(
428    track: &crate::RemoteVideoTrack,
429) -> impl Stream<Item = RemoteVideoFrame> + use<> {
430    #[cfg(target_os = "macos")]
431    {
432        let mut pool = None;
433        let most_recent_frame_size = (0, 0);
434        NativeVideoStream::new(track.0.rtc_track()).filter_map(move |frame| {
435            if pool == None
436                || most_recent_frame_size != (frame.buffer.width(), frame.buffer.height())
437            {
438                pool = create_buffer_pool(frame.buffer.width(), frame.buffer.height()).log_err();
439            }
440            let pool = pool.clone();
441            async move {
442                if frame.buffer.width() < 10 && frame.buffer.height() < 10 {
443                    // when the remote stops sharing, we get an 8x8 black image.
444                    // In a lil bit, the unpublish will come through and close the view,
445                    // but until then, don't flash black.
446                    return None;
447                }
448
449                video_frame_buffer_from_webrtc(pool?, frame.buffer)
450            }
451        })
452    }
453    #[cfg(not(target_os = "macos"))]
454    {
455        NativeVideoStream::new(track.0.rtc_track())
456            .filter_map(|frame| async move { video_frame_buffer_from_webrtc(frame.buffer) })
457    }
458}
459
/// Creates a `CVPixelBufferPool` producing `width`×`height` NV12
/// (420YpCbCr8 bi-planar, full range) buffers that are IOSurface/Core
/// Animation compatible, for reuse across decoded frames.
#[cfg(target_os = "macos")]
fn create_buffer_pool(
    width: u32,
    height: u32,
) -> Result<core_video::pixel_buffer_pool::CVPixelBufferPool> {
    use core_foundation::{base::TCFType, number::CFNumber, string::CFString};
    use core_video::pixel_buffer;
    use core_video::{
        pixel_buffer::kCVPixelFormatType_420YpCbCr8BiPlanarFullRange,
        pixel_buffer_io_surface::kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey,
        pixel_buffer_pool::{self},
    };

    // SAFETY-style note: the `kCVPixelBuffer…Key` constants are static CF
    // strings owned by the system, so `wrap_under_get_rule` (which retains)
    // is the appropriate wrapper.
    let width_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferWidthKey) };
    let height_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferHeightKey) };
    let animation_key: CFString = unsafe {
        CFString::wrap_under_get_rule(kCVPixelBufferIOSurfaceCoreAnimationCompatibilityKey)
    };
    let format_key: CFString =
        unsafe { CFString::wrap_under_get_rule(pixel_buffer::kCVPixelBufferPixelFormatTypeKey) };

    // CF booleans are expressed here as the number 1 ("true").
    let yes: CFNumber = 1.into();
    let width: CFNumber = (width as i32).into();
    let height: CFNumber = (height as i32).into();
    let format: CFNumber = (kCVPixelFormatType_420YpCbCr8BiPlanarFullRange as i64).into();

    let buffer_attributes = core_foundation::dictionary::CFDictionary::from_CFType_pairs(&[
        (width_key, width.into_CFType()),
        (height_key, height.into_CFType()),
        (animation_key, yes.into_CFType()),
        (format_key, format.into_CFType()),
    ]);

    pixel_buffer_pool::CVPixelBufferPool::new(None, Some(&buffer_attributes)).map_err(|cv_return| {
        anyhow::anyhow!("failed to create pixel buffer pool: CVReturn({cv_return})",)
    })
}
499
/// On macOS, remote video frames are delivered as Core Video pixel buffers.
#[cfg(target_os = "macos")]
pub type RemoteVideoFrame = core_video::pixel_buffer::CVPixelBuffer;
502
/// Converts a WebRTC video buffer into a `CVPixelBuffer` for display.
///
/// Fast path: if the buffer is already backed by a native `CVPixelBuffer`,
/// it is retained and returned without copying. Otherwise the I420 planes
/// are converted into an NV12 pixel buffer drawn from `pool`.
/// Returns `None` on any failure (null native buffer, pool exhaustion, or a
/// base-address lock/unlock error).
#[cfg(target_os = "macos")]
fn video_frame_buffer_from_webrtc(
    pool: core_video::pixel_buffer_pool::CVPixelBufferPool,
    buffer: Box<dyn VideoBuffer>,
) -> Option<RemoteVideoFrame> {
    use core_foundation::base::TCFType;
    use core_video::{pixel_buffer::CVPixelBuffer, r#return::kCVReturnSuccess};
    use livekit::webrtc::native::yuv_helper::i420_to_nv12;

    if let Some(native) = buffer.as_native() {
        let pixel_buffer = native.get_cv_pixel_buffer();
        if pixel_buffer.is_null() {
            return None;
        }
        // `wrap_under_get_rule` retains, so the returned CVPixelBuffer owns
        // its own reference independent of `buffer`.
        return unsafe { Some(CVPixelBuffer::wrap_under_get_rule(pixel_buffer as _)) };
    }

    let i420_buffer = buffer.as_i420()?;
    let pixel_buffer = pool.create_pixel_buffer().log_err()?;

    let image_buffer = unsafe {
        // The base address must be locked while CPU code reads/writes the
        // buffer's planes.
        if pixel_buffer.lock_base_address(0) != kCVReturnSuccess {
            return None;
        }

        // Plane 0 is luma (Y), plane 1 is interleaved chroma (UV) in NV12.
        let dst_y = pixel_buffer.get_base_address_of_plane(0);
        let dst_y_stride = pixel_buffer.get_bytes_per_row_of_plane(0);
        let dst_y_len = pixel_buffer.get_height_of_plane(0) * dst_y_stride;
        let dst_uv = pixel_buffer.get_base_address_of_plane(1);
        let dst_uv_stride = pixel_buffer.get_bytes_per_row_of_plane(1);
        let dst_uv_len = pixel_buffer.get_height_of_plane(1) * dst_uv_stride;
        let width = pixel_buffer.get_width();
        let height = pixel_buffer.get_height();
        // SAFETY: the plane pointers are valid for `height * stride` bytes
        // while the base address is locked.
        let dst_y_buffer = std::slice::from_raw_parts_mut(dst_y as *mut u8, dst_y_len);
        let dst_uv_buffer = std::slice::from_raw_parts_mut(dst_uv as *mut u8, dst_uv_len);

        let (stride_y, stride_u, stride_v) = i420_buffer.strides();
        let (src_y, src_u, src_v) = i420_buffer.data();
        i420_to_nv12(
            src_y,
            stride_y,
            src_u,
            stride_u,
            src_v,
            stride_v,
            dst_y_buffer,
            dst_y_stride as u32,
            dst_uv_buffer,
            dst_uv_stride as u32,
            width as i32,
            height as i32,
        );

        if pixel_buffer.unlock_base_address(0) != kCVReturnSuccess {
            return None;
        }

        pixel_buffer
    };

    Some(image_buffer)
}
565
/// On non-macOS platforms, remote video frames are decoded into shared
/// GPUI render images.
#[cfg(not(target_os = "macos"))]
pub type RemoteVideoFrame = Arc<gpui::RenderImage>;
568
/// Converts a WebRTC video buffer into a GPUI `RenderImage` by rendering it
/// to a 4-bytes-per-pixel frame. Returns `None` if allocation fails or the
/// resulting bytes don't form a valid image.
#[cfg(not(target_os = "macos"))]
fn video_frame_buffer_from_webrtc(buffer: Box<dyn VideoBuffer>) -> Option<RemoteVideoFrame> {
    use gpui::RenderImage;
    use image::{Frame, RgbaImage};
    use livekit::webrtc::prelude::VideoFormatType;
    use smallvec::SmallVec;
    use std::alloc::{Layout, alloc};

    let width = buffer.width();
    let height = buffer.height();
    // 4 bytes per pixel.
    let stride = width * 4;
    let byte_len = (stride * height) as usize;
    let argb_image = unsafe {
        // Motivation for this unsafe code is to avoid initializing the frame data, since to_argb
        // will write all bytes anyway.
        let start_ptr = alloc(Layout::array::<u8>(byte_len).log_err()?);
        if start_ptr.is_null() {
            return None;
        }
        // SAFETY: `start_ptr` was just allocated with a layout of exactly
        // `byte_len` bytes, and `to_argb` fills the entire slice.
        let argb_frame_slice = std::slice::from_raw_parts_mut(start_ptr, byte_len);
        buffer.to_argb(
            VideoFormatType::ARGB,
            argb_frame_slice,
            stride,
            width as i32,
            height as i32,
        );
        Vec::from_raw_parts(start_ptr, byte_len, byte_len)
    };

    // TODO: Unclear why providing argb_image to RgbaImage works properly.
    let image = RgbaImage::from_raw(width, height, argb_image)
        .with_context(|| "Bug: not enough bytes allocated for image.")
        .log_err()?;

    Some(Arc::new(RenderImage::new(SmallVec::from_elem(
        Frame::new(image),
        1,
    ))))
}
609
/// Wraps a captured Core Video pixel buffer in a WebRTC native buffer
/// without copying the pixel data.
#[cfg(target_os = "macos")]
fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
    use livekit::webrtc;

    let pixel_buffer = frame.0.as_concrete_TypeRef();
    // Forget the frame so its Drop doesn't release the pixel buffer; the
    // reference is presumably taken over by the native buffer below —
    // NOTE(review): confirm `from_cv_pixel_buffer`'s ownership contract.
    std::mem::forget(frame.0);
    unsafe {
        Some(webrtc::video_frame::native::NativeBuffer::from_cv_pixel_buffer(pixel_buffer as _))
    }
}
620
621#[cfg(not(target_os = "macos"))]
622fn video_frame_buffer_to_webrtc(frame: ScreenCaptureFrame) -> Option<impl AsRef<dyn VideoBuffer>> {
623    use libwebrtc::native::yuv_helper::{abgr_to_nv12, argb_to_nv12};
624    use livekit::webrtc::prelude::NV12Buffer;
625    match frame.0 {
626        scap::frame::Frame::BGRx(frame) => {
627            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
628            let (stride_y, stride_uv) = buffer.strides();
629            let (data_y, data_uv) = buffer.data_mut();
630            argb_to_nv12(
631                &frame.data,
632                frame.width as u32 * 4,
633                data_y,
634                stride_y,
635                data_uv,
636                stride_uv,
637                frame.width,
638                frame.height,
639            );
640            Some(buffer)
641        }
642        scap::frame::Frame::RGBx(frame) => {
643            let mut buffer = NV12Buffer::new(frame.width as u32, frame.height as u32);
644            let (stride_y, stride_uv) = buffer.strides();
645            let (data_y, data_uv) = buffer.data_mut();
646            abgr_to_nv12(
647                &frame.data,
648                frame.width as u32 * 4,
649                data_y,
650                stride_y,
651                data_uv,
652                stride_uv,
653                frame.width,
654                frame.height,
655            );
656            Some(buffer)
657        }
658        scap::frame::Frame::YUVFrame(yuvframe) => {
659            let mut buffer = NV12Buffer::with_strides(
660                yuvframe.width as u32,
661                yuvframe.height as u32,
662                yuvframe.luminance_stride as u32,
663                yuvframe.chrominance_stride as u32,
664            );
665            let (luminance, chrominance) = buffer.data_mut();
666            luminance.copy_from_slice(yuvframe.luminance_bytes.as_slice());
667            chrominance.copy_from_slice(yuvframe.chrominance_bytes.as_slice());
668            Some(buffer)
669        }
670        _ => {
671            log::error!(
672                "Expected BGRx or YUV frame from scap screen capture but got some other format."
673            );
674            None
675        }
676    }
677}
678
/// A stream that yields `()` whenever the watched default audio device
/// changes; the capture/playback loops await it to rebuild their streams.
trait DeviceChangeListenerApi: Stream<Item = ()> + Sized {
    /// Creates a listener; `input` selects whether the default input or
    /// default output device is watched.
    fn new(input: bool) -> Result<Self>;
}
682
683#[cfg(target_os = "macos")]
684mod macos {
685
686    use coreaudio::sys::{
687        AudioObjectAddPropertyListener, AudioObjectID, AudioObjectPropertyAddress,
688        AudioObjectRemovePropertyListener, OSStatus, kAudioHardwarePropertyDefaultInputDevice,
689        kAudioHardwarePropertyDefaultOutputDevice, kAudioObjectPropertyElementMaster,
690        kAudioObjectPropertyScopeGlobal, kAudioObjectSystemObject,
691    };
692    use futures::{StreamExt, channel::mpsc::UnboundedReceiver};
693
694    /// Implementation from: https://github.com/zed-industries/cpal/blob/fd8bc2fd39f1f5fdee5a0690656caff9a26d9d50/src/host/coreaudio/macos/property_listener.rs#L15
695    pub struct CoreAudioDefaultDeviceChangeListener {
696        rx: UnboundedReceiver<()>,
697        callback: Box<PropertyListenerCallbackWrapper>,
698        input: bool,
699        device_id: AudioObjectID, // Store the device ID to properly remove listeners
700    }
701
702    trait _AssertSend: Send {}
703    impl _AssertSend for CoreAudioDefaultDeviceChangeListener {}
704
705    struct PropertyListenerCallbackWrapper(Box<dyn FnMut() + Send>);
706
707    unsafe extern "C" fn property_listener_handler_shim(
708        _: AudioObjectID,
709        _: u32,
710        _: *const AudioObjectPropertyAddress,
711        callback: *mut ::std::os::raw::c_void,
712    ) -> OSStatus {
713        let wrapper = callback as *mut PropertyListenerCallbackWrapper;
714        unsafe { (*wrapper).0() };
715        0
716    }
717
718    impl super::DeviceChangeListenerApi for CoreAudioDefaultDeviceChangeListener {
719        fn new(input: bool) -> anyhow::Result<Self> {
720            let (tx, rx) = futures::channel::mpsc::unbounded();
721
722            let callback = Box::new(PropertyListenerCallbackWrapper(Box::new(move || {
723                tx.unbounded_send(()).ok();
724            })));
725
726            // Get the current default device ID
727            let device_id = unsafe {
728                // Listen for default device changes
729                coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
730                    kAudioObjectSystemObject,
731                    &AudioObjectPropertyAddress {
732                        mSelector: if input {
733                            kAudioHardwarePropertyDefaultInputDevice
734                        } else {
735                            kAudioHardwarePropertyDefaultOutputDevice
736                        },
737                        mScope: kAudioObjectPropertyScopeGlobal,
738                        mElement: kAudioObjectPropertyElementMaster,
739                    },
740                    Some(property_listener_handler_shim),
741                    &*callback as *const _ as *mut _,
742                ))?;
743
744                // Also listen for changes to the device configuration
745                let device_id = if input {
746                    let mut input_device: AudioObjectID = 0;
747                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
748                    let result = coreaudio::sys::AudioObjectGetPropertyData(
749                        kAudioObjectSystemObject,
750                        &AudioObjectPropertyAddress {
751                            mSelector: kAudioHardwarePropertyDefaultInputDevice,
752                            mScope: kAudioObjectPropertyScopeGlobal,
753                            mElement: kAudioObjectPropertyElementMaster,
754                        },
755                        0,
756                        std::ptr::null(),
757                        &mut prop_size as *mut _,
758                        &mut input_device as *mut _ as *mut _,
759                    );
760                    if result != 0 {
761                        log::warn!("Failed to get default input device ID");
762                        0
763                    } else {
764                        input_device
765                    }
766                } else {
767                    let mut output_device: AudioObjectID = 0;
768                    let mut prop_size = std::mem::size_of::<AudioObjectID>() as u32;
769                    let result = coreaudio::sys::AudioObjectGetPropertyData(
770                        kAudioObjectSystemObject,
771                        &AudioObjectPropertyAddress {
772                            mSelector: kAudioHardwarePropertyDefaultOutputDevice,
773                            mScope: kAudioObjectPropertyScopeGlobal,
774                            mElement: kAudioObjectPropertyElementMaster,
775                        },
776                        0,
777                        std::ptr::null(),
778                        &mut prop_size as *mut _,
779                        &mut output_device as *mut _ as *mut _,
780                    );
781                    if result != 0 {
782                        log::warn!("Failed to get default output device ID");
783                        0
784                    } else {
785                        output_device
786                    }
787                };
788
789                if device_id != 0 {
790                    // Listen for format changes on the device
791                    coreaudio::Error::from_os_status(AudioObjectAddPropertyListener(
792                        device_id,
793                        &AudioObjectPropertyAddress {
794                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
795                            mScope: if input {
796                                coreaudio::sys::kAudioObjectPropertyScopeInput
797                            } else {
798                                coreaudio::sys::kAudioObjectPropertyScopeOutput
799                            },
800                            mElement: kAudioObjectPropertyElementMaster,
801                        },
802                        Some(property_listener_handler_shim),
803                        &*callback as *const _ as *mut _,
804                    ))?;
805                }
806
807                device_id
808            };
809
810            Ok(Self {
811                rx,
812                callback,
813                input,
814                device_id,
815            })
816        }
817    }
818
    // Tears down the CoreAudio property listeners so the
    // `property_listener_handler_shim` can no longer fire against the
    // `callback` allocation once this listener is gone.
    impl Drop for CoreAudioDefaultDeviceChangeListener {
        fn drop(&mut self) {
            unsafe {
                // Remove the system-level property listener
                // (default input/output device changes on
                // kAudioObjectSystemObject — presumably registered in `new`;
                // the registration site is outside this block).
                // The OSStatus result is deliberately ignored: there is no
                // useful recovery from a failed removal inside Drop.
                AudioObjectRemovePropertyListener(
                    kAudioObjectSystemObject,
                    &AudioObjectPropertyAddress {
                        // Must mirror the selector used at registration time,
                        // which depends on whether this instance watches the
                        // default input or the default output device.
                        mSelector: if self.input {
                            kAudioHardwarePropertyDefaultInputDevice
                        } else {
                            kAudioHardwarePropertyDefaultOutputDevice
                        },
                        mScope: kAudioObjectPropertyScopeGlobal,
                        mElement: kAudioObjectPropertyElementMaster,
                    },
                    Some(property_listener_handler_shim),
                    // Same context pointer that was passed when adding the
                    // listener; CoreAudio matches on (address, proc, context).
                    &*self.callback as *const _ as *mut _,
                );

                // Remove the device-specific property listener if we have a valid device ID
                // (`0` is the sentinel meaning the default device could not be
                // resolved, so no stream-format listener was registered).
                if self.device_id != 0 {
                    AudioObjectRemovePropertyListener(
                        self.device_id,
                        &AudioObjectPropertyAddress {
                            mSelector: coreaudio::sys::kAudioDevicePropertyStreamFormat,
                            mScope: if self.input {
                                coreaudio::sys::kAudioObjectPropertyScopeInput
                            } else {
                                coreaudio::sys::kAudioObjectPropertyScopeOutput
                            },
                            mElement: kAudioObjectPropertyElementMaster,
                        },
                        Some(property_listener_handler_shim),
                        &*self.callback as *const _ as *mut _,
                    );
                }
            }
        }
    }
858
859    impl futures::Stream for CoreAudioDefaultDeviceChangeListener {
860        type Item = ();
861
862        fn poll_next(
863            mut self: std::pin::Pin<&mut Self>,
864            cx: &mut std::task::Context<'_>,
865        ) -> std::task::Poll<Option<Self::Item>> {
866            self.rx.poll_next_unpin(cx)
867        }
868    }
869}
870
// Platform alias: on macOS, device-change events come from the
// CoreAudio-backed listener defined in the `macos` module above.
#[cfg(target_os = "macos")]
type DeviceChangeListener = macos::CoreAudioDefaultDeviceChangeListener;
873
874#[cfg(not(target_os = "macos"))]
875mod noop_change_listener {
876    use std::task::Poll;
877
878    use super::DeviceChangeListenerApi;
879
880    pub struct NoopOutputDeviceChangelistener {}
881
882    impl DeviceChangeListenerApi for NoopOutputDeviceChangelistener {
883        fn new(_input: bool) -> anyhow::Result<Self> {
884            Ok(NoopOutputDeviceChangelistener {})
885        }
886    }
887
888    impl futures::Stream for NoopOutputDeviceChangelistener {
889        type Item = ();
890
891        fn poll_next(
892            self: std::pin::Pin<&mut Self>,
893            _cx: &mut std::task::Context<'_>,
894        ) -> Poll<Option<Self::Item>> {
895            Poll::Pending
896        }
897    }
898}
899
// Platform alias: on every other OS no notification mechanism is wired up,
// so the listener is the no-op stream above, which never yields an event.
#[cfg(not(target_os = "macos"))]
type DeviceChangeListener = noop_change_listener::NoopOutputDeviceChangelistener;