Checkpoint

Nathan Sobo created

Change summary

crates/gpui2/src/elements/node.rs | 264 +++++++++++++++++++-------------
1 file changed, 153 insertions(+), 111 deletions(-)

Detailed changes

crates/gpui2/src/elements/node.rs 🔗

@@ -557,17 +557,20 @@ impl<V: 'static> Element<V> for Node<V> {
     fn initialize(
         &mut self,
         view_state: &mut V,
-        previous_element_state: Option<Self::ElementState>,
+        element_state: Option<Self::ElementState>,
         cx: &mut ViewContext<V>,
     ) -> Self::ElementState {
-        for child in &mut self.children {
-            child.initialize(view_state, cx);
-        }
+        let interactive_state =
+            self.interactivity
+                .initialize(element_state.map(|s| s.interactive_state), cx, |cx| {
+                    for child in &mut self.children {
+                        child.initialize(view_state, cx);
+                    }
+                });
+
         NodeState {
+            interactive_state,
             child_layout_ids: SmallVec::new(),
-            interactive_state: previous_element_state
-                .map(|s| s.interactive_state)
-                .unwrap_or_default(),
         }
     }
 
@@ -664,7 +667,7 @@ pub struct Interactivity<V> {
     hovered: bool,
     group_hovered: bool,
     key_context: KeyContext,
-    focus_handle: Option<FocusHandle>,
+    tracked_focus_handle: Option<FocusHandle>,
     focusable: bool,
     scroll_offset: Point<Pixels>,
     group: Option<SharedString>,
@@ -692,110 +695,29 @@ pub struct Interactivity<V> {
     tooltip_builder: Option<TooltipBuilder<V>>,
 }
 
-#[derive(Default)]
-pub struct InteractiveElementState {
-    focus_handle: Option<FocusHandle>,
-    clicked_state: Arc<Mutex<ElementClickedState>>,
-    hover_state: Arc<Mutex<bool>>,
-    pending_mouse_down: Arc<Mutex<Option<MouseDownEvent>>>,
-    scroll_offset: Option<Arc<Mutex<Point<Pixels>>>>,
-    active_tooltip: Arc<Mutex<Option<ActiveTooltip>>>,
-}
-
-struct ActiveTooltip {
-    #[allow(unused)] // used to drop the task
-    waiting: Option<Task<()>>,
-    tooltip: Option<AnyTooltip>,
-}
-
-/// Whether or not the element or a group that contains it is clicked by the mouse.
-#[derive(Copy, Clone, Default, Eq, PartialEq)]
-struct ElementClickedState {
-    pub group: bool,
-    pub element: bool,
-}
-
-impl ElementClickedState {
-    fn is_clicked(&self) -> bool {
-        self.group || self.element
-    }
-}
-
 impl<V> Interactivity<V>
 where
     V: 'static,
 {
-    fn compute_style(
-        &self,
-        bounds: Option<Bounds<Pixels>>,
-        element_state: &mut InteractiveElementState,
+    fn initialize(
+        &mut self,
+        element_state: Option<InteractiveElementState>,
         cx: &mut ViewContext<V>,
-    ) -> Style {
-        let mut style = Style::default();
-        style.refine(&self.base_style);
-
-        if let Some(focus_handle) = self.focus_handle.as_ref() {
-            if focus_handle.contains_focused(cx) {
-                style.refine(&self.focus_in_style);
-            }
-
-            if focus_handle.within_focused(cx) {
-                style.refine(&self.in_focus_style);
-            }
-
-            if focus_handle.is_focused(cx) {
-                style.refine(&self.focus_style);
-            }
-        }
-
-        if let Some(bounds) = bounds {
-            let mouse_position = cx.mouse_position();
-            if let Some(group_hover) = self.group_hover_style.as_ref() {
-                if let Some(group_bounds) = GroupBounds::get(&group_hover.group, cx) {
-                    if group_bounds.contains_point(&mouse_position) {
-                        style.refine(&group_hover.style);
-                    }
-                }
-            }
-            if bounds.contains_point(&mouse_position) {
-                style.refine(&self.hover_style);
-            }
-
-            if let Some(drag) = cx.active_drag.take() {
-                for (state_type, group_drag_style) in &self.group_drag_over_styles {
-                    if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
-                        if *state_type == drag.view.entity_type()
-                            && group_bounds.contains_point(&mouse_position)
-                        {
-                            style.refine(&group_drag_style.style);
-                        }
-                    }
-                }
-
-                for (state_type, drag_over_style) in &self.drag_over_styles {
-                    if *state_type == drag.view.entity_type()
-                        && bounds.contains_point(&mouse_position)
-                    {
-                        style.refine(drag_over_style);
-                    }
-                }
-
-                cx.active_drag = Some(drag);
-            }
-        }
-
-        let clicked_state = element_state.clicked_state.lock();
-        if clicked_state.group {
-            if let Some(group) = self.group_active_style.as_ref() {
-                style.refine(&group.style)
-            }
-        }
-
-        if clicked_state.element {
-            style.refine(&self.active_style)
+        f: impl FnOnce(&mut ViewContext<V>),
+    ) -> InteractiveElementState {
+        let mut element_state = element_state.unwrap_or_default();
+        // Ensure we store a focus handle in our element state if we're focusable.
+        // If there's an explicit focus handle we're tracking, use that. Otherwise
+        // create a new handle and store it in the element state, which lives for as
+        // as frames contain an element with this id.
+        if self.focusable {
+            element_state.focus_handle.get_or_insert_with(|| {
+                self.tracked_focus_handle
+                    .clone()
+                    .unwrap_or_else(|| cx.focus_handle())
+            });
         }
-
-        style
+        element_state
     }
 
     fn layout(
@@ -807,7 +729,7 @@ where
         let style = self.compute_style(None, element_state, cx);
         cx.with_key_dispatch(
             self.key_context.clone(),
-            self.focus_handle.clone(),
+            self.tracked_focus_handle.clone(),
             |_, cx| f(style, cx),
         )
     }
@@ -1091,14 +1013,105 @@ where
 
         cx.with_key_dispatch(
             self.key_context.clone(),
-            self.focus_handle.clone(),
-            |_, cx| f(style, self.scroll_offset, cx),
+            self.tracked_focus_handle.clone(),
+            |_, cx| {
+                for listener in self.key_down_listeners.drain(..) {
+                    cx.on_key_event(move |state, event: &KeyDownEvent, phase, cx| {
+                        listener(state, event, phase, cx);
+                    })
+                }
+
+                for listener in self.key_up_listeners.drain(..) {
+                    cx.on_key_event(move |state, event: &KeyUpEvent, phase, cx| {
+                        listener(state, event, phase, cx);
+                    })
+                }
+
+                for (action_type, listener) in self.action_listeners.drain(..) {
+                    cx.on_action(action_type, listener)
+                }
+
+                f(style, self.scroll_offset, cx)
+            },
         );
 
         if let Some(group) = self.group.as_ref() {
             GroupBounds::pop(group, cx);
         }
     }
+
+    fn compute_style(
+        &self,
+        bounds: Option<Bounds<Pixels>>,
+        element_state: &mut InteractiveElementState,
+        cx: &mut ViewContext<V>,
+    ) -> Style {
+        let mut style = Style::default();
+        style.refine(&self.base_style);
+
+        if let Some(focus_handle) = self.tracked_focus_handle.as_ref() {
+            if focus_handle.contains_focused(cx) {
+                style.refine(&self.focus_in_style);
+            }
+
+            if focus_handle.within_focused(cx) {
+                style.refine(&self.in_focus_style);
+            }
+
+            if focus_handle.is_focused(cx) {
+                style.refine(&self.focus_style);
+            }
+        }
+
+        if let Some(bounds) = bounds {
+            let mouse_position = cx.mouse_position();
+            if let Some(group_hover) = self.group_hover_style.as_ref() {
+                if let Some(group_bounds) = GroupBounds::get(&group_hover.group, cx) {
+                    if group_bounds.contains_point(&mouse_position) {
+                        style.refine(&group_hover.style);
+                    }
+                }
+            }
+            if bounds.contains_point(&mouse_position) {
+                style.refine(&self.hover_style);
+            }
+
+            if let Some(drag) = cx.active_drag.take() {
+                for (state_type, group_drag_style) in &self.group_drag_over_styles {
+                    if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
+                        if *state_type == drag.view.entity_type()
+                            && group_bounds.contains_point(&mouse_position)
+                        {
+                            style.refine(&group_drag_style.style);
+                        }
+                    }
+                }
+
+                for (state_type, drag_over_style) in &self.drag_over_styles {
+                    if *state_type == drag.view.entity_type()
+                        && bounds.contains_point(&mouse_position)
+                    {
+                        style.refine(drag_over_style);
+                    }
+                }
+
+                cx.active_drag = Some(drag);
+            }
+        }
+
+        let clicked_state = element_state.clicked_state.lock();
+        if clicked_state.group {
+            if let Some(group) = self.group_active_style.as_ref() {
+                style.refine(&group.style)
+            }
+        }
+
+        if clicked_state.element {
+            style.refine(&self.active_style)
+        }
+
+        style
+    }
 }
 
 impl<V: 'static> Default for Interactivity<V> {
@@ -1107,7 +1120,7 @@ impl<V: 'static> Default for Interactivity<V> {
             hovered: false,
             group_hovered: false,
             key_context: KeyContext::default(),
-            focus_handle: None,
+            tracked_focus_handle: None,
             scroll_offset: Point::default(),
             group: None,
             base_style: StyleRefinement::default(),
@@ -1132,11 +1145,40 @@ impl<V: 'static> Default for Interactivity<V> {
             drag_listener: None,
             hover_listener: None,
             tooltip_builder: None,
-            focusable: todo!(),
+            focusable: false,
         }
     }
 }
 
+#[derive(Default)]
+pub struct InteractiveElementState {
+    focus_handle: Option<FocusHandle>,
+    clicked_state: Arc<Mutex<ElementClickedState>>,
+    hover_state: Arc<Mutex<bool>>,
+    pending_mouse_down: Arc<Mutex<Option<MouseDownEvent>>>,
+    scroll_offset: Option<Arc<Mutex<Point<Pixels>>>>,
+    active_tooltip: Arc<Mutex<Option<ActiveTooltip>>>,
+}
+
+struct ActiveTooltip {
+    #[allow(unused)] // used to drop the task
+    waiting: Option<Task<()>>,
+    tooltip: Option<AnyTooltip>,
+}
+
+/// Whether or not the element or a group that contains it is clicked by the mouse.
+#[derive(Copy, Clone, Default, Eq, PartialEq)]
+struct ElementClickedState {
+    pub group: bool,
+    pub element: bool,
+}
+
+impl ElementClickedState {
+    fn is_clicked(&self) -> bool {
+        self.group || self.element
+    }
+}
+
 #[derive(Default)]
 pub struct GroupBounds(HashMap<SharedString, SmallVec<[Bounds<Pixels>; 1]>>);