Prevent `ArenaRef` from being cloned

Antonio Scandurra created

Change summary

crates/gpui2/src/arena.rs        |  9 -----
crates/gpui2/src/key_dispatch.rs |  5 ++
crates/gpui2/src/window.rs       | 57 +++++++++++++++++++++++++--------
3 files changed, 46 insertions(+), 25 deletions(-)

Detailed changes

crates/gpui2/src/arena.rs 🔗

@@ -98,15 +98,6 @@ pub struct ArenaRef<T: ?Sized> {
     valid: Rc<Cell<bool>>,
 }
 
-impl<T: ?Sized> Clone for ArenaRef<T> {
-    fn clone(&self) -> Self {
-        Self {
-            ptr: self.ptr,
-            valid: self.valid.clone(),
-        }
-    }
-}
-
 impl<T: ?Sized> ArenaRef<T> {
     #[inline(always)]
     pub fn map<U: ?Sized>(mut self, f: impl FnOnce(&mut T) -> &mut U) -> ArenaRef<U> {

crates/gpui2/src/key_dispatch.rs 🔗

@@ -35,7 +35,6 @@ pub(crate) struct DispatchNode {
 
 type KeyListener = ArenaRef<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>;
 
-#[derive(Clone)]
 pub(crate) struct DispatchActionListener {
     pub(crate) action_type: TypeId,
     pub(crate) listener: ArenaRef<dyn Fn(&dyn Any, DispatchPhase, &mut WindowContext)>,
@@ -267,6 +266,10 @@ impl DispatchTree {
         &self.nodes[node_id.0]
     }
 
+    pub fn node_mut(&mut self, node_id: DispatchNodeId) -> &mut DispatchNode {
+        &mut self.nodes[node_id.0]
+    }
+
     fn active_node(&mut self) -> &mut DispatchNode {
         let active_node_id = self.active_node_id();
         &mut self.nodes[active_node_id.0]

crates/gpui2/src/window.rs 🔗

@@ -1572,30 +1572,43 @@ impl<'a> WindowContext<'a> {
         self.propagate_event = true;
 
         for node_id in &dispatch_path {
-            let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
-
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
             if let Some(context) = node.context.clone() {
                 context_stack.push(context);
             }
 
-            for key_listener in node.key_listeners.clone() {
+            let key_listeners = mem::take(&mut node.key_listeners);
+            for key_listener in &key_listeners {
                 key_listener(event, DispatchPhase::Capture, self);
                 if !self.propagate_event {
-                    return;
+                    break;
                 }
             }
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            node.key_listeners = key_listeners;
+
+            if !self.propagate_event {
+                return;
+            }
         }
 
         // Bubble phase
         for node_id in dispatch_path.iter().rev() {
             // Handle low level key events
-            let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
-            for key_listener in node.key_listeners.clone() {
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            let key_listeners = mem::take(&mut node.key_listeners);
+            for key_listener in &key_listeners {
                 key_listener(event, DispatchPhase::Bubble, self);
                 if !self.propagate_event {
-                    return;
+                    break;
                 }
             }
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            node.key_listeners = key_listeners;
+
+            if !self.propagate_event {
+                return;
+            }
 
             // Match keystrokes
             let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
@@ -1639,38 +1652,52 @@ impl<'a> WindowContext<'a> {
 
         // Capture phase
         for node_id in &dispatch_path {
-            let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            let action_listeners = mem::take(&mut node.action_listeners);
             for DispatchActionListener {
                 action_type,
                 listener,
-            } in node.action_listeners.clone()
+            } in &action_listeners
             {
                 let any_action = action.as_any();
-                if action_type == any_action.type_id() {
+                if *action_type == any_action.type_id() {
                     listener(any_action, DispatchPhase::Capture, self);
                     if !self.propagate_event {
-                        return;
+                        break;
                     }
                 }
             }
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            node.action_listeners = action_listeners;
+
+            if !self.propagate_event {
+                return;
+            }
         }
         // Bubble phase
         for node_id in dispatch_path.iter().rev() {
-            let node = self.window.rendered_frame.dispatch_tree.node(*node_id);
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            let action_listeners = mem::take(&mut node.action_listeners);
             for DispatchActionListener {
                 action_type,
                 listener,
-            } in node.action_listeners.clone()
+            } in &action_listeners
             {
                 let any_action = action.as_any();
-                if action_type == any_action.type_id() {
+                if *action_type == any_action.type_id() {
                     self.propagate_event = false; // Actions stop propagation by default during the bubble phase
                     listener(any_action, DispatchPhase::Bubble, self);
                     if !self.propagate_event {
-                        return;
+                        break;
                     }
                 }
             }
+
+            let node = self.window.rendered_frame.dispatch_tree.node_mut(*node_id);
+            node.action_listeners = action_listeners;
+            if !self.propagate_event {
+                return;
+            }
         }
     }