Use an Arena to reuse allocations for listeners

Antonio Scandurra created

Change summary

crates/gpui2/src/arena.rs        | 124 ++++++++++++++++++++++++++++++++++
crates/gpui2/src/gpui2.rs        |   1 
crates/gpui2/src/key_dispatch.rs |  10 +-
crates/gpui2/src/window.rs       | 105 ++++++++++++++++------------
4 files changed, 191 insertions(+), 49 deletions(-)

Detailed changes

crates/gpui2/src/arena.rs 🔗

@@ -0,0 +1,124 @@
+use std::{
+    alloc,
+    ptr::{self, NonNull},
+};
+
+pub struct Arena {
+    start: NonNull<u8>,
+    offset: usize,
+    elements: Vec<ArenaElement>,
+}
+
+impl Default for Arena {
+    fn default() -> Self {
+        unsafe {
+            let layout = alloc::Layout::from_size_align(16 * 1024 * 1024, 1).unwrap();
+            let ptr = alloc::alloc(layout);
+            Self {
+                start: NonNull::new_unchecked(ptr),
+                offset: 0,
+                elements: Vec::new(),
+            }
+        }
+    }
+}
+
+struct ArenaElement {
+    value: NonNull<u8>,
+    drop: unsafe fn(NonNull<u8>),
+}
+
+impl Arena {
+    pub fn clear(&mut self) {
+        for element in self.elements.drain(..) {
+            unsafe {
+                (element.drop)(element.value);
+            }
+        }
+        self.offset = 0;
+    }
+
+    #[inline(always)]
+    pub fn alloc<T>(&mut self, value: T) -> ArenaRef<T> {
+        unsafe fn drop<T>(ptr: NonNull<u8>) {
+            std::ptr::drop_in_place(ptr.cast::<T>().as_ptr());
+        }
+
+        unsafe {
+            let layout = alloc::Layout::for_value(&value).pad_to_align();
+            let value_ptr = self.start.as_ptr().add(self.offset).cast::<T>();
+            ptr::write(value_ptr, value);
+
+            let value = NonNull::new_unchecked(value_ptr);
+            self.elements.push(ArenaElement {
+                value: value.cast(),
+                drop: drop::<T>,
+            });
+            self.offset += layout.size();
+            ArenaRef(value)
+        }
+    }
+}
+
+pub struct ArenaRef<T: ?Sized>(NonNull<T>);
+
+impl<T: ?Sized> Copy for ArenaRef<T> {}
+
+impl<T: ?Sized> Clone for ArenaRef<T> {
+    fn clone(&self) -> Self {
+        Self(self.0)
+    }
+}
+
+impl<T: ?Sized> ArenaRef<T> {
+    pub unsafe fn map<U: ?Sized>(mut self, f: impl FnOnce(&mut T) -> &mut U) -> ArenaRef<U> {
+        let u = f(self.get_mut());
+        ArenaRef(NonNull::new_unchecked(u))
+    }
+
+    pub unsafe fn get_mut(&mut self) -> &mut T {
+        self.0.as_mut()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{cell::Cell, rc::Rc};
+
+    use super::*;
+
+    #[test]
+    fn test_arena() {
+        let mut arena = Arena::default();
+        let mut a = arena.alloc(1u64);
+        let mut b = arena.alloc(2u32);
+        let mut c = arena.alloc(3u16);
+        let mut d = arena.alloc(4u8);
+        assert_eq!(unsafe { *a.get_mut() }, 1);
+        assert_eq!(unsafe { *b.get_mut() }, 2);
+        assert_eq!(unsafe { *c.get_mut() }, 3);
+        assert_eq!(unsafe { *d.get_mut() }, 4);
+
+        arena.clear();
+        let mut a = arena.alloc(5u64);
+        let mut b = arena.alloc(6u32);
+        let mut c = arena.alloc(7u16);
+        let mut d = arena.alloc(8u8);
+        assert_eq!(unsafe { *a.get_mut() }, 5);
+        assert_eq!(unsafe { *b.get_mut() }, 6);
+        assert_eq!(unsafe { *c.get_mut() }, 7);
+        assert_eq!(unsafe { *d.get_mut() }, 8);
+
+        // Ensure drop gets called.
+        let dropped = Rc::new(Cell::new(false));
+        struct DropGuard(Rc<Cell<bool>>);
+        impl Drop for DropGuard {
+            fn drop(&mut self) {
+                self.0.set(true);
+            }
+        }
+        arena.alloc(DropGuard(dropped.clone()));
+        arena.clear();
+        assert!(dropped.get());
+    }
+}

crates/gpui2/src/key_dispatch.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
-    Action, ActionRegistry, DispatchPhase, FocusId, KeyBinding, KeyContext, KeyMatch, Keymap,
-    Keystroke, KeystrokeMatcher, WindowContext,
+    arena::ArenaRef, Action, ActionRegistry, DispatchPhase, FocusId, KeyBinding, KeyContext,
+    KeyMatch, Keymap, Keystroke, KeystrokeMatcher, WindowContext,
 };
 use collections::HashMap;
 use parking_lot::Mutex;
@@ -33,12 +33,12 @@ pub(crate) struct DispatchNode {
     parent: Option<DispatchNodeId>,
 }
 
-type KeyListener = Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
+type KeyListener = ArenaRef<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
 
 #[derive(Clone)]
 pub(crate) struct DispatchActionListener {
     pub(crate) action_type: TypeId,
-    pub(crate) listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
+    pub(crate) listener: ArenaRef<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
 }
 
 impl DispatchTree {
@@ -117,7 +117,7 @@ impl DispatchTree {
     pub fn on_action(
         &mut self,
         action_type: TypeId,
-        listener: Rc<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
+        listener: ArenaRef<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
     ) {
         self.active_node()
             .action_listeners

crates/gpui2/src/window.rs 🔗

@@ -1,15 +1,17 @@
 use crate::{
-    key_dispatch::DispatchActionListener, px, size, transparent_black, Action, AnyDrag, AnyView,
-    AppContext, AsyncWindowContext, AvailableSpace, Bounds, BoxShadow, Context, Corners,
-    CursorStyle, DevicePixels, DispatchNodeId, DispatchTree, DisplayId, Edges, Effect, Entity,
-    EntityId, EventEmitter, FileDropEvent, Flatten, FontId, GlobalElementId, GlyphId, Hsla,
-    ImageData, InputEvent, IsZero, KeyBinding, KeyContext, KeyDownEvent, KeystrokeEvent, LayoutId,
-    Model, ModelContext, Modifiers, MonochromeSprite, MouseButton, MouseMoveEvent, MouseUpEvent,
-    Path, Pixels, PlatformAtlas, PlatformDisplay, PlatformInputHandler, PlatformWindow, Point,
-    PolychromeSprite, PromptLevel, Quad, Render, RenderGlyphParams, RenderImageParams,
-    RenderSvgParams, ScaledPixels, Scene, SceneBuilder, Shadow, SharedString, Size, Style,
-    SubscriberSet, Subscription, Surface, TaffyLayoutEngine, Task, Underline, UnderlineStyle, View,
-    VisualContext, WeakView, WindowBounds, WindowOptions, SUBPIXEL_VARIANTS,
+    arena::{Arena, ArenaRef},
+    key_dispatch::DispatchActionListener,
+    px, size, transparent_black, Action, AnyDrag, AnyView, AppContext, AsyncWindowContext,
+    AvailableSpace, Bounds, BoxShadow, Context, Corners, CursorStyle, DevicePixels, DispatchNodeId,
+    DispatchTree, DisplayId, Edges, Effect, Entity, EntityId, EventEmitter, FileDropEvent, Flatten,
+    FontId, GlobalElementId, GlyphId, Hsla, ImageData, InputEvent, IsZero, KeyBinding, KeyContext,
+    KeyDownEvent, KeystrokeEvent, LayoutId, Model, ModelContext, Modifiers, MonochromeSprite,
+    MouseButton, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas, PlatformDisplay,
+    PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptLevel, Quad, Render,
+    RenderGlyphParams, RenderImageParams, RenderSvgParams, ScaledPixels, Scene, SceneBuilder,
+    Shadow, SharedString, Size, Style, SubscriberSet, Subscription, Surface, TaffyLayoutEngine,
+    Task, Underline, UnderlineStyle, View, VisualContext, WeakView, WindowBounds, WindowOptions,
+    SUBPIXEL_VARIANTS,
 };
 use anyhow::{anyhow, Context as _, Result};
 use collections::FxHashMap;
@@ -85,7 +87,7 @@ impl DispatchPhase {
 }
 
 type AnyObserver = Box<dyn FnMut(&mut WindowContext) -> bool + 'static>;
-type AnyMouseListener = Box<dyn FnMut(&dyn Any, DispatchPhase, &mut WindowContext) + 'static>;
+type AnyMouseListener = ArenaRef<dyn FnMut(&dyn Any, DispatchPhase, &mut WindowContext) + 'static>;
 type AnyWindowFocusListener = Box<dyn FnMut(&FocusEvent, &mut WindowContext) -> bool + 'static>;
 
 struct FocusEvent {
@@ -268,9 +270,9 @@ pub(crate) struct ElementStateBox {
     type_name: &'static str,
 }
 
-// #[derive(Default)]
 pub(crate) struct Frame {
     focus: Option<FocusId>,
+    arena: Arena,
     pub(crate) element_states: FxHashMap<GlobalElementId, ElementStateBox>,
     mouse_listeners: FxHashMap<TypeId, Vec<(StackingOrder, AnyMouseListener)>>,
     pub(crate) dispatch_tree: DispatchTree,
@@ -285,6 +287,7 @@ impl Frame {
     fn new(dispatch_tree: DispatchTree) -> Self {
         Frame {
             focus: None,
+            arena: Arena::default(),
             element_states: FxHashMap::default(),
             mouse_listeners: FxHashMap::default(),
             dispatch_tree,
@@ -299,6 +302,7 @@ impl Frame {
     fn clear(&mut self) {
         self.element_states.clear();
         self.mouse_listeners.values_mut().for_each(Vec::clear);
+        self.arena.clear();
         self.dispatch_tree.clear();
         self.depth_map.clear();
     }
@@ -818,25 +822,23 @@ impl<'a> WindowContext<'a> {
     /// Register a mouse event listener on the window for the next frame. The type of event
     /// is determined by the first parameter of the given listener. When the next frame is rendered
     /// the listener will be cleared.
-    ///
-    /// This is a fairly low-level method, so prefer using event handlers on elements unless you have
-    /// a specific need to register a global listener.
     pub fn on_mouse_event<Event: 'static>(
         &mut self,
         mut handler: impl FnMut(&Event, DispatchPhase, &mut WindowContext) + 'static,
     ) {
         let order = self.window.next_frame.z_index_stack.clone();
+        let handler = self.window.next_frame.arena.alloc(
+            move |event: &dyn Any, phase: DispatchPhase, cx: &mut WindowContext<'_>| {
+                handler(event.downcast_ref().unwrap(), phase, cx)
+            },
+        );
+        let handler = unsafe { handler.map(|handler| handler as _) };
         self.window
             .next_frame
             .mouse_listeners
             .entry(TypeId::of::<Event>())
             .or_default()
-            .push((
-                order,
-                Box::new(move |event: &dyn Any, phase, cx| {
-                    handler(event.downcast_ref().unwrap(), phase, cx)
-                }),
-            ))
+            .push((order, handler))
     }
 
     /// Register a key event listener on the window for the next frame. The type of event
@@ -847,16 +849,17 @@ impl<'a> WindowContext<'a> {
     /// a specific need to register a global listener.
     pub fn on_key_event<Event: 'static>(
         &mut self,
-        handler: impl Fn(&Event, DispatchPhase, &mut WindowContext) + 'static,
+        listener: impl Fn(&Event, DispatchPhase, &mut WindowContext) + 'static,
     ) {
-        self.window
-            .next_frame
-            .dispatch_tree
-            .on_key_event(Rc::new(move |event, phase, cx| {
+        let listener = self.window.next_frame.arena.alloc(
+            move |event: &dyn Any, phase, cx: &mut WindowContext<'_>| {
                 if let Some(event) = event.downcast_ref::<Event>() {
-                    handler(event, phase, cx)
+                    listener(event, phase, cx)
                 }
-            }));
+            },
+        );
+        let listener = unsafe { listener.map(|handler| handler as _) };
+        self.window.next_frame.dispatch_tree.on_key_event(listener);
     }
 
     /// Register an action listener on the window for the next frame. The type of action
@@ -868,12 +871,14 @@ impl<'a> WindowContext<'a> {
     pub fn on_action(
         &mut self,
         action_type: TypeId,
-        handler: impl Fn(&dyn Any, DispatchPhase, &mut WindowContext) + 'static,
+        listener: impl Fn(&dyn Any, DispatchPhase, &mut WindowContext) + 'static,
     ) {
-        self.window.next_frame.dispatch_tree.on_action(
-            action_type,
-            Rc::new(move |action, phase, cx| handler(action, phase, cx)),
-        );
+        let listener = self.window.next_frame.arena.alloc(listener);
+        let listener = unsafe { listener.map(|handler| handler as _) };
+        self.window
+            .next_frame
+            .dispatch_tree
+            .on_action(action_type, listener);
     }
 
     pub fn is_action_available(&self, action: &dyn Action) -> bool {
@@ -1274,10 +1279,16 @@ impl<'a> WindowContext<'a> {
             cx.with_key_dispatch(Some(KeyContext::default()), None, |_, cx| {
                 for (action_type, action_listeners) in &cx.app.global_action_listeners {
                     for action_listener in action_listeners.iter().cloned() {
-                        cx.window.next_frame.dispatch_tree.on_action(
-                            *action_type,
-                            Rc::new(move |action, phase, cx| action_listener(action, phase, cx)),
-                        )
+                        let listener = cx.window.next_frame.arena.alloc(
+                            move |action: &dyn Any, phase, cx: &mut WindowContext<'_>| {
+                                action_listener(action, phase, cx)
+                            },
+                        );
+                        let listener = unsafe { listener.map(|listener| listener as _) };
+                        cx.window
+                            .next_frame
+                            .dispatch_tree
+                            .on_action(*action_type, listener)
                     }
                 }
 
@@ -1460,6 +1471,7 @@ impl<'a> WindowContext<'a> {
             // Capture phase, events bubble from back to front. Handlers for this phase are used for
             // special purposes, such as detecting events outside of a given Bounds.
             for (_, handler) in &mut handlers {
+                let handler = unsafe { handler.get_mut() };
                 handler(event, DispatchPhase::Capture, self);
                 if !self.app.propagate_event {
                     break;
@@ -1469,6 +1481,7 @@ impl<'a> WindowContext<'a> {
             // Bubble phase, where most normal handlers do their work.
             if self.app.propagate_event {
                 for (_, handler) in handlers.iter_mut().rev() {
+                    let handler = unsafe { handler.get_mut() };
                     handler(event, DispatchPhase::Bubble, self);
                     if !self.app.propagate_event {
                         break;
@@ -1518,7 +1531,8 @@ impl<'a> WindowContext<'a> {
                 context_stack.push(context);
             }
 
-            for key_listener in node.key_listeners.clone() {
+            for mut key_listener in node.key_listeners.clone() {
+                let key_listener = unsafe { key_listener.get_mut() };
                 key_listener(event, DispatchPhase::Capture, self);
                 if !self.propagate_event {
                     return;
@@ -1530,7 +1544,8 @@ impl<'a> WindowContext<'a> {
         for node_id in dispatch_path.iter().rev() {
             // Handle low level key events
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
-            for key_listener in node.key_listeners.clone() {
+            for mut key_listener in node.key_listeners.clone() {
+                let key_listener = unsafe { key_listener.get_mut() };
                 key_listener(event, DispatchPhase::Bubble, self);
                 if !self.propagate_event {
                     return;
@@ -1582,11 +1597,12 @@ impl<'a> WindowContext<'a> {
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
             for DispatchActionListener {
                 action_type,
-                listener,
+                mut listener,
             } in node.action_listeners.clone()
             {
                 let any_action = action.as_any();
                 if action_type == any_action.type_id() {
+                    let listener = unsafe { listener.get_mut() };
                     listener(any_action, DispatchPhase::Capture, self);
                     if !self.propagate_event {
                         return;
@@ -1599,12 +1615,13 @@ impl<'a> WindowContext<'a> {
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
             for DispatchActionListener {
                 action_type,
-                listener,
+                mut listener,
             } in node.action_listeners.clone()
             {
                 let any_action = action.as_any();
                 if action_type == any_action.type_id() {
                     self.propagate_event = false; // Actions stop propagation by default during the bubble phase
+                    let listener = unsafe { listener.get_mut() };
                     listener(any_action, DispatchPhase::Bubble, self);
                     if !self.propagate_event {
                         return;
@@ -2593,13 +2610,13 @@ impl<'a, V: 'static> ViewContext<'a, V> {
     pub fn on_action(
         &mut self,
         action_type: TypeId,
-        handler: impl Fn(&mut V, &dyn Any, DispatchPhase, &mut ViewContext<V>) + 'static,
+        listener: impl Fn(&mut V, &dyn Any, DispatchPhase, &mut ViewContext<V>) + 'static,
     ) {
         let handle = self.view().clone();
         self.window_cx
             .on_action(action_type, move |action, phase, cx| {
                 handle.update(cx, |view, cx| {
-                    handler(view, action, phase, cx);
+                    listener(view, action, phase, cx);
                 })
             });
     }