WIP: Start converting H264 samples to Annex-B NALs

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/media/build.rs     |   1 
crates/media/src/media.rs | 117 ++++++++++++++++++++++++++++++++++++++--
2 files changed, 110 insertions(+), 8 deletions(-)

Detailed changes

crates/media/build.rs 🔗

@@ -26,6 +26,7 @@ fn main() {
         .allowlist_var("VTEncodeInfoFlags_.*")
         .allowlist_var("kCMVideoCodecType_.*")
         .allowlist_var("kCMTime.*")
+        .allowlist_var("kCMSampleAttachmentKey_.*")
         .parse_callbacks(Box::new(bindgen::CargoCallbacks))
         .layout_tests(false)
         .generate()

crates/media/src/media.rs 🔗

@@ -208,7 +208,7 @@ pub mod core_media {
         impl_CFTypeDescription, impl_TCFType,
         string::CFString,
     };
-    use std::ffi::c_void;
+    use std::{ffi::c_void, ptr};
 
     #[repr(C)]
     pub struct __CMSampleBuffer(c_void);
@@ -261,6 +261,14 @@ pub mod core_media {
                 }
             }
         }
+
+        /// Returns the format description associated with this sample buffer.
+        ///
+        /// The reference returned by `CMSampleBufferGetFormatDescription` is wrapped
+        /// under the get rule.
+        pub fn format_description(&self) -> CMFormatDescription {
+            let buffer_ref = self.as_concrete_TypeRef();
+            unsafe {
+                let description_ref = CMSampleBufferGetFormatDescription(buffer_ref);
+                CMFormatDescription::wrap_under_get_rule(description_ref)
+            }
+        }
     }
 
     #[link(name = "CoreMedia", kind = "framework")]
@@ -276,6 +284,71 @@ pub mod core_media {
             index: CMItemIndex,
             timing_info_out: *mut CMSampleTimingInfo,
         ) -> OSStatus;
+        fn CMSampleBufferGetFormatDescription(buffer: CMSampleBufferRef) -> CMFormatDescriptionRef;
+    }
+
+    // Opaque stand-in for the CoreMedia format description object; it is only ever
+    // handled through `CMFormatDescriptionRef` pointers.
+    #[repr(C)]
+    pub struct __CMFormatDescription(c_void);
+    // The ref type must be a pointer to the underlying struct.
+    pub type CMFormatDescriptionRef = *const __CMFormatDescription;
+
+    // Declare the `CMFormatDescription` wrapper over `CMFormatDescriptionRef` and
+    // implement `TCFType` for it, using `CMFormatDescriptionGetTypeID` as its type ID
+    // function.
+    declare_TCFType!(CMFormatDescription, CMFormatDescriptionRef);
+    impl_TCFType!(
+        CMFormatDescription,
+        CMFormatDescriptionRef,
+        CMFormatDescriptionGetTypeID
+    );
+    impl_CFTypeDescription!(CMFormatDescription);
+
+    impl CMFormatDescription {
+        /// Returns the number of H.264 parameter sets reported for this format
+        /// description.
+        ///
+        /// # Panics
+        ///
+        /// Panics if `CMVideoFormatDescriptionGetH264ParameterSetAtIndex` returns a
+        /// non-zero status.
+        pub fn h264_parameter_set_count(&self) -> usize {
+            let mut count = 0;
+            // Only the count out-parameter is requested; every other out-pointer is null.
+            let result = unsafe {
+                CMVideoFormatDescriptionGetH264ParameterSetAtIndex(
+                    self.as_concrete_TypeRef(),
+                    0,
+                    ptr::null_mut(),
+                    ptr::null_mut(),
+                    &mut count,
+                    ptr::null_mut(),
+                )
+            };
+            assert_eq!(result, 0);
+            count
+        }
+
+        /// Returns the bytes of the H.264 parameter set at `index`.
+        ///
+        /// The returned slice borrows from `self`.
+        ///
+        /// # Errors
+        ///
+        /// Returns an error if `CMVideoFormatDescriptionGetH264ParameterSetAtIndex`
+        /// returns a non-zero status, or if it reports success but leaves the data
+        /// pointer null.
+        pub fn h264_parameter_set_at_index(&self, index: usize) -> Result<&[u8]> {
+            let mut bytes: *const u8 = ptr::null();
+            let mut len = 0;
+            // Only the data pointer and its length are requested here.
+            let result = unsafe {
+                CMVideoFormatDescriptionGetH264ParameterSetAtIndex(
+                    self.as_concrete_TypeRef(),
+                    index,
+                    &mut bytes,
+                    &mut len,
+                    ptr::null_mut(),
+                    ptr::null_mut(),
+                )
+            };
+            if result != 0 {
+                return Err(anyhow!("error getting parameter set, code: {}", result));
+            }
+            // `slice::from_raw_parts` requires a non-null pointer even for empty slices,
+            // so never hand it a pointer the call left untouched.
+            if bytes.is_null() {
+                return Err(anyhow!(
+                    "parameter set at index {} has a null data pointer",
+                    index
+                ));
+            }
+            // SAFETY: the call reported success and `bytes` is non-null.
+            // NOTE(review): assumes the `len` bytes stay valid for as long as the format
+            // description is alive — confirm against the CoreMedia documentation.
+            Ok(unsafe { std::slice::from_raw_parts(bytes, len) })
+        }
+    }
+
+    #[link(name = "CoreMedia", kind = "framework")]
+    extern "C" {
+        fn CMFormatDescriptionGetTypeID() -> CFTypeID;
+        // Callers in this module pass null for the out-parameters they do not need.
+        // The C prototype declares the final parameter as `int *`, so it is bound as
+        // `c_int` rather than a pointer-sized integer.
+        fn CMVideoFormatDescriptionGetH264ParameterSetAtIndex(
+            video_desc: CMFormatDescriptionRef,
+            parameter_set_index: usize,
+            parameter_set_pointer_out: *mut *const u8,
+            parameter_set_size_out: *mut usize,
+            parameter_set_count_out: *mut usize,
+            nal_unit_header_length_out: *mut std::os::raw::c_int,
+        ) -> OSStatus;
+    }
 }
 
@@ -284,15 +357,17 @@ pub mod video_toolbox {
 
     use super::*;
     use crate::{
-        core_media::{CMSampleBufferRef, CMTime, CMVideoCodecType},
+        core_media::{CMSampleBuffer, CMSampleBufferRef, CMTime, CMVideoCodecType},
         core_video::CVImageBufferRef,
     };
     use anyhow::{anyhow, Result};
     use bindings::VTEncodeInfoFlags;
     use core_foundation::{
         base::OSStatus,
-        dictionary::{CFDictionary, CFDictionaryRef, CFMutableDictionary},
+        dictionary::CFDictionaryRef,
         mach_port::CFAllocatorRef,
+        number::{CFBooleanGetValue, CFBooleanRef},
+        string::CFStringRef,
     };
     use std::ptr;
 
@@ -343,13 +418,39 @@ pub mod video_toolbox {
             }
         }
 
-        extern "C" fn output(
-            outputCallbackRefCon: *mut c_void,
-            sourceFrameRefCon: *mut c_void,
+        /// Encoder output callback: receives the status of an encode operation and the
+        /// resulting sample buffer.
+        ///
+        /// On a non-zero `status` the error is printed and the frame is dropped. For a
+        /// sync sample (key frame), every H.264 parameter set from the buffer's format
+        /// description is appended to `stream`, each preceded by a 4-byte start code.
+        ///
+        /// # Safety
+        ///
+        /// When `status` is zero, `sample_buffer` must be a valid `CMSampleBufferRef`;
+        /// it is wrapped under the get rule.
+        unsafe extern "C" fn output(
+            output_callback_ref_con: *mut c_void,
+            source_frame_ref_con: *mut c_void,
             status: OSStatus,
-            infoFlags: VTEncodeInfoFlags,
-            sampleBuffer: CMSampleBufferRef,
+            info_flags: VTEncodeInfoFlags,
+            sample_buffer: CMSampleBufferRef,
         ) {
+            if status != 0 {
+                println!("error encoding frame, code: {}", status);
+                return;
+            }
+            let sample_buffer = CMSampleBuffer::wrap_under_get_rule(sample_buffer);
+
+            let mut is_iframe = false;
+            let attachments = sample_buffer.attachments();
+            if let Some(attachments) = attachments.first() {
+                // `NotSync == true` marks a sample that is NOT a sync sample, so the flag
+                // is negated; a missing key means the sample is a sync sample.
+                is_iframe = attachments
+                    .find(bindings::kCMSampleAttachmentKey_NotSync as CFStringRef)
+                    .map_or(true, |not_sync| {
+                        !CFBooleanGetValue(*not_sync as CFBooleanRef)
+                    });
+            }
+
+            const START_CODE: [u8; 4] = [0x00, 0x00, 0x00, 0x01];
+            if is_iframe {
+                let format_description = sample_buffer.format_description();
+                // NOTE(review): `stream` is not defined in the visible code and
+                // `parameter_set` is still a `Result`; presumably `stream` is the output
+                // byte buffer and the error should be handled before extending — confirm.
+                for ix in 0..format_description.h264_parameter_set_count() {
+                    let parameter_set = format_description.h264_parameter_set_at_index(ix);
+                    stream.extend(START_CODE);
+                    stream.extend(parameter_set);
+                }
+            }
+
             println!("YO!");
         }