Use a safe API for Arena

Antonio Scandurra created

Change summary

crates/gpui2/src/arena.rs   | 132 +++++++++++++++++++++++++-------------
crates/gpui2/src/element.rs |  12 +-
crates/gpui2/src/window.rs  |  70 ++++++++++----------
3 files changed, 125 insertions(+), 89 deletions(-)

Detailed changes

crates/gpui2/src/arena.rs 🔗

@@ -1,40 +1,49 @@
 use std::{
     alloc,
+    cell::Cell,
+    ops::{Deref, DerefMut},
     ptr::{self, NonNull},
+    rc::Rc,
 };
 
+struct ArenaElement {
+    value: NonNull<u8>,
+    drop: unsafe fn(NonNull<u8>),
+}
+
+impl Drop for ArenaElement {
+    fn drop(&mut self) {
+        unsafe {
+            (self.drop)(self.value);
+        }
+    }
+}
+
 pub struct Arena {
     start: NonNull<u8>,
     offset: usize,
     elements: Vec<ArenaElement>,
+    valid: Rc<Cell<bool>>,
 }
 
-impl Default for Arena {
-    fn default() -> Self {
+impl Arena {
+    pub fn new(size_in_bytes: usize) -> Self {
         unsafe {
-            let layout = alloc::Layout::from_size_align(16 * 1024 * 1024, 1).unwrap();
+            let layout = alloc::Layout::from_size_align(size_in_bytes, 1).unwrap();
             let ptr = alloc::alloc(layout);
             Self {
                 start: NonNull::new_unchecked(ptr),
                 offset: 0,
                 elements: Vec::new(),
+                valid: Rc::new(Cell::new(true)),
             }
         }
     }
-}
 
-struct ArenaElement {
-    value: NonNull<u8>,
-    drop: unsafe fn(NonNull<u8>),
-}
-
-impl Arena {
     pub fn clear(&mut self) {
-        for element in self.elements.drain(..) {
-            unsafe {
-                (element.drop)(element.value);
-            }
-        }
+        self.valid.set(false);
+        self.valid = Rc::new(Cell::new(true));
+        self.elements.clear();
         self.offset = 0;
     }
 
@@ -46,42 +55,71 @@ impl Arena {
 
         unsafe {
             let layout = alloc::Layout::for_value(&value).pad_to_align();
-            let value_ptr = self.start.as_ptr().add(self.offset).cast::<T>();
-            ptr::write(value_ptr, value);
+            let ptr = NonNull::new_unchecked(self.start.as_ptr().add(self.offset).cast::<T>());
+            ptr::write(ptr.as_ptr(), value);
 
-            let value = NonNull::new_unchecked(value_ptr);
             self.elements.push(ArenaElement {
-                value: value.cast(),
+                value: ptr.cast(),
                 drop: drop::<T>,
             });
             self.offset += layout.size();
-            ArenaRef(value)
+            ArenaRef {
+                ptr,
+                valid: self.valid.clone(),
+            }
         }
     }
 }
 
-pub struct ArenaRef<T: ?Sized>(NonNull<T>);
+impl Drop for Arena {
+    fn drop(&mut self) {
+        self.clear();
+    }
+}
 
-impl<T: ?Sized> Copy for ArenaRef<T> {}
+pub struct ArenaRef<T: ?Sized> {
+    ptr: NonNull<T>,
+    valid: Rc<Cell<bool>>,
+}
 
 impl<T: ?Sized> Clone for ArenaRef<T> {
     fn clone(&self) -> Self {
-        Self(self.0)
+        Self {
+            ptr: self.ptr,
+            valid: self.valid.clone(),
+        }
     }
 }
 
 impl<T: ?Sized> ArenaRef<T> {
-    pub unsafe fn map<U: ?Sized>(mut self, f: impl FnOnce(&mut T) -> &mut U) -> ArenaRef<U> {
-        let u = f(self.get_mut());
-        ArenaRef(NonNull::new_unchecked(u))
+    pub fn map<U: ?Sized>(mut self, f: impl FnOnce(&mut T) -> &mut U) -> ArenaRef<U> {
+        ArenaRef {
+            ptr: unsafe { NonNull::new_unchecked(f(&mut *self)) },
+            valid: self.valid,
+        }
     }
 
-    pub unsafe fn get(&self) -> &T {
-        self.0.as_ref()
+    fn validate(&self) {
+        assert!(
+            self.valid.get(),
+            "attempted to dereference an ArenaRef after its Arena was cleared"
+        );
     }
+}
+
+impl<T: ?Sized> Deref for ArenaRef<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.validate();
+        unsafe { self.ptr.as_ref() }
+    }
+}
 
-    pub unsafe fn get_mut(&mut self) -> &mut T {
-        self.0.as_mut()
+impl<T: ?Sized> DerefMut for ArenaRef<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.validate();
+        unsafe { self.ptr.as_mut() }
     }
 }
 
@@ -93,25 +131,25 @@ mod tests {
 
     #[test]
     fn test_arena() {
-        let mut arena = Arena::default();
-        let mut a = arena.alloc(1u64);
-        let mut b = arena.alloc(2u32);
-        let mut c = arena.alloc(3u16);
-        let mut d = arena.alloc(4u8);
-        assert_eq!(unsafe { *a.get_mut() }, 1);
-        assert_eq!(unsafe { *b.get_mut() }, 2);
-        assert_eq!(unsafe { *c.get_mut() }, 3);
-        assert_eq!(unsafe { *d.get_mut() }, 4);
+        let mut arena = Arena::new(1024);
+        let a = arena.alloc(1u64);
+        let b = arena.alloc(2u32);
+        let c = arena.alloc(3u16);
+        let d = arena.alloc(4u8);
+        assert_eq!(*a, 1);
+        assert_eq!(*b, 2);
+        assert_eq!(*c, 3);
+        assert_eq!(*d, 4);
 
         arena.clear();
-        let mut a = arena.alloc(5u64);
-        let mut b = arena.alloc(6u32);
-        let mut c = arena.alloc(7u16);
-        let mut d = arena.alloc(8u8);
-        assert_eq!(unsafe { *a.get_mut() }, 5);
-        assert_eq!(unsafe { *b.get_mut() }, 6);
-        assert_eq!(unsafe { *c.get_mut() }, 7);
-        assert_eq!(unsafe { *d.get_mut() }, 8);
+        let a = arena.alloc(5u64);
+        let b = arena.alloc(6u32);
+        let c = arena.alloc(7u16);
+        let d = arena.alloc(8u8);
+        assert_eq!(*a, 5);
+        assert_eq!(*b, 6);
+        assert_eq!(*c, 7);
+        assert_eq!(*d, 8);
 
         // Ensure drop gets called.
         let dropped = Rc::new(Cell::new(false));

crates/gpui2/src/element.rs 🔗

@@ -415,16 +415,16 @@ impl AnyElement {
     {
         let element =
             FRAME_ARENA.with_borrow_mut(|arena| arena.alloc(Some(DrawableElement::new(element))));
-        let element = unsafe { element.map(|element| element as &mut dyn ElementObject) };
+        let element = element.map(|element| element as &mut dyn ElementObject);
         AnyElement(element)
     }
 
     pub fn layout(&mut self, cx: &mut WindowContext) -> LayoutId {
-        unsafe { self.0.get_mut() }.layout(cx)
+        self.0.layout(cx)
     }
 
     pub fn paint(&mut self, cx: &mut WindowContext) {
-        unsafe { self.0.get_mut() }.paint(cx)
+        self.0.paint(cx)
     }
 
     /// Initializes this element and performs layout within the given available space to determine its size.
@@ -433,7 +433,7 @@ impl AnyElement {
         available_space: Size<AvailableSpace>,
         cx: &mut WindowContext,
     ) -> Size<Pixels> {
-        unsafe { self.0.get_mut() }.measure(available_space, cx)
+        self.0.measure(available_space, cx)
     }
 
     /// Initializes this element and performs layout in the available space, then paints it at the given origin.
@@ -443,11 +443,11 @@ impl AnyElement {
         available_space: Size<AvailableSpace>,
         cx: &mut WindowContext,
     ) {
-        unsafe { self.0.get_mut() }.draw(origin, available_space, cx)
+        self.0.draw(origin, available_space, cx)
     }
 
     pub fn inner_id(&self) -> Option<ElementId> {
-        unsafe { self.0.get() }.element_id()
+        self.0.element_id()
     }
 }
 

crates/gpui2/src/window.rs 🔗

@@ -99,7 +99,7 @@ struct FocusEvent {
 slotmap::new_key_type! { pub struct FocusId; }
 
 thread_local! {
-    pub static FRAME_ARENA: RefCell<Arena> = RefCell::new(Arena::default());
+    pub static FRAME_ARENA: RefCell<Arena> = RefCell::new(Arena::new(16 * 1024 * 1024));
 }
 
 impl FocusId {
@@ -829,14 +829,15 @@ impl<'a> WindowContext<'a> {
         mut handler: impl FnMut(&Event, DispatchPhase, &mut WindowContext) + 'static,
     ) {
         let order = self.window.next_frame.z_index_stack.clone();
-        let handler = FRAME_ARENA.with_borrow_mut(|arena| {
-            arena.alloc(
-                move |event: &dyn Any, phase: DispatchPhase, cx: &mut WindowContext<'_>| {
-                    handler(event.downcast_ref().unwrap(), phase, cx)
-                },
-            )
-        });
-        let handler = unsafe { handler.map(|handler| handler as _) };
+        let handler = FRAME_ARENA
+            .with_borrow_mut(|arena| {
+                arena.alloc(
+                    move |event: &dyn Any, phase: DispatchPhase, cx: &mut WindowContext<'_>| {
+                        handler(event.downcast_ref().unwrap(), phase, cx)
+                    },
+                )
+            })
+            .map(|handler| handler as _);
         self.window
             .next_frame
             .mouse_listeners
@@ -855,14 +856,15 @@ impl<'a> WindowContext<'a> {
         &mut self,
         listener: impl Fn(&Event, DispatchPhase, &mut WindowContext) + 'static,
     ) {
-        let listener = FRAME_ARENA.with_borrow_mut(|arena| {
-            arena.alloc(move |event: &dyn Any, phase, cx: &mut WindowContext<'_>| {
-                if let Some(event) = event.downcast_ref::<Event>() {
-                    listener(event, phase, cx)
-                }
+        let listener = FRAME_ARENA
+            .with_borrow_mut(|arena| {
+                arena.alloc(move |event: &dyn Any, phase, cx: &mut WindowContext<'_>| {
+                    if let Some(event) = event.downcast_ref::<Event>() {
+                        listener(event, phase, cx)
+                    }
+                })
             })
-        });
-        let listener = unsafe { listener.map(|handler| handler as _) };
+            .map(|handler| handler as _);
         self.window.next_frame.dispatch_tree.on_key_event(listener);
     }
 
@@ -877,8 +879,9 @@ impl<'a> WindowContext<'a> {
         action_type: TypeId,
         listener: impl Fn(&dyn Any, DispatchPhase, &mut WindowContext) + 'static,
     ) {
-        let listener = FRAME_ARENA.with_borrow_mut(|arena| arena.alloc(listener));
-        let listener = unsafe { listener.map(|handler| handler as _) };
+        let listener = FRAME_ARENA
+            .with_borrow_mut(|arena| arena.alloc(listener))
+            .map(|handler| handler as _);
         self.window
             .next_frame
             .dispatch_tree
@@ -1284,14 +1287,15 @@ impl<'a> WindowContext<'a> {
             cx.with_key_dispatch(Some(KeyContext::default()), None, |_, cx| {
                 for (action_type, action_listeners) in &cx.app.global_action_listeners {
                     for action_listener in action_listeners.iter().cloned() {
-                        let listener = FRAME_ARENA.with_borrow_mut(|arena| {
-                            arena.alloc(
-                                move |action: &dyn Any, phase, cx: &mut WindowContext<'_>| {
-                                    action_listener(action, phase, cx)
-                                },
-                            )
-                        });
-                        let listener = unsafe { listener.map(|listener| listener as _) };
+                        let listener = FRAME_ARENA
+                            .with_borrow_mut(|arena| {
+                                arena.alloc(
+                                    move |action: &dyn Any, phase, cx: &mut WindowContext<'_>| {
+                                        action_listener(action, phase, cx)
+                                    },
+                                )
+                            })
+                            .map(|listener| listener as _);
                         cx.window
                             .next_frame
                             .dispatch_tree
@@ -1478,7 +1482,6 @@ impl<'a> WindowContext<'a> {
             // Capture phase, events bubble from back to front. Handlers for this phase are used for
             // special purposes, such as detecting events outside of a given Bounds.
             for (_, handler) in &mut handlers {
-                let handler = unsafe { handler.get_mut() };
                 handler(event, DispatchPhase::Capture, self);
                 if !self.app.propagate_event {
                     break;
@@ -1488,7 +1491,6 @@ impl<'a> WindowContext<'a> {
             // Bubble phase, where most normal handlers do their work.
             if self.app.propagate_event {
                 for (_, handler) in handlers.iter_mut().rev() {
-                    let handler = unsafe { handler.get_mut() };
                     handler(event, DispatchPhase::Bubble, self);
                     if !self.app.propagate_event {
                         break;
@@ -1538,8 +1540,7 @@ impl<'a> WindowContext<'a> {
                 context_stack.push(context);
             }
 
-            for mut key_listener in node.key_listeners.clone() {
-                let key_listener = unsafe { key_listener.get_mut() };
+            for key_listener in node.key_listeners.clone() {
                 key_listener(event, DispatchPhase::Capture, self);
                 if !self.propagate_event {
                     return;
@@ -1551,8 +1552,7 @@ impl<'a> WindowContext<'a> {
         for node_id in dispatch_path.iter().rev() {
             // Handle low level key events
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
-            for mut key_listener in node.key_listeners.clone() {
-                let key_listener = unsafe { key_listener.get_mut() };
+            for key_listener in node.key_listeners.clone() {
                 key_listener(event, DispatchPhase::Bubble, self);
                 if !self.propagate_event {
                     return;
@@ -1604,12 +1604,11 @@ impl<'a> WindowContext<'a> {
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
             for DispatchActionListener {
                 action_type,
-                mut listener,
+                listener,
             } in node.action_listeners.clone()
             {
                 let any_action = action.as_any();
                 if action_type == any_action.type_id() {
-                    let listener = unsafe { listener.get_mut() };
                     listener(any_action, DispatchPhase::Capture, self);
                     if !self.propagate_event {
                         return;
@@ -1622,13 +1621,12 @@ impl<'a> WindowContext<'a> {
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
             for DispatchActionListener {
                 action_type,
-                mut listener,
+                listener,
             } in node.action_listeners.clone()
             {
                 let any_action = action.as_any();
                 if action_type == any_action.type_id() {
                     self.propagate_event = false; // Actions stop propagation by default during the bubble phase
-                    let listener = unsafe { listener.get_mut() };
                     listener(any_action, DispatchPhase::Bubble, self);
                     if !self.propagate_event {
                         return;