Checkpoint

Antonio Scandurra created

Change summary

crates/gpui3/src/elements/div.rs       | 46 +++++++++----
crates/gpui3/src/elements/img.rs       | 15 +--
crates/gpui3/src/elements/svg.rs       | 14 +--
crates/gpui3/src/focusable.rs          | 95 ++++++++++++++++++++-------
crates/gpui3/src/interactive.rs        | 19 +---
crates/storybook2/src/stories/focus.rs |  8 +-
6 files changed, 120 insertions(+), 77 deletions(-)

Detailed changes

crates/gpui3/src/elements/div.rs 🔗

@@ -141,14 +141,14 @@ where
     pub fn compute_style(
         &self,
         bounds: Bounds<Pixels>,
-        state: &InteractiveElementState,
+        element_state: &DivState,
         cx: &mut ViewContext<V>,
     ) -> Style {
         let mut computed_style = Style::default();
         computed_style.refine(&self.base_style);
         self.focus.refine_style(&mut computed_style, cx);
         self.interaction
-            .refine_style(&mut computed_style, bounds, state, cx);
+            .refine_style(&mut computed_style, bounds, &element_state.interactive, cx);
         computed_style
     }
 }
@@ -157,13 +157,23 @@ impl<V> Div<V, StatefulInteraction<V>, FocusDisabled>
 where
     V: 'static + Send + Sync,
 {
-    pub fn focusable(
+    pub fn focusable(self) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
+        Div {
+            interaction: self.interaction,
+            focus: FocusEnabled::new(),
+            children: self.children,
+            group: self.group,
+            base_style: self.base_style,
+        }
+    }
+
+    pub fn track_focus(
         self,
         handle: &FocusHandle,
     ) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
         Div {
             interaction: self.interaction,
-            focus: handle.clone().into(),
+            focus: FocusEnabled::tracked(handle),
             children: self.children,
             group: self.group,
             base_style: self.base_style,
@@ -175,7 +185,7 @@ impl<V> Div<V, StatelessInteraction<V>, FocusDisabled>
 where
     V: 'static + Send + Sync,
 {
-    pub fn focusable(
+    pub fn track_focus(
         self,
         handle: &FocusHandle,
     ) -> Div<V, StatefulInteraction<V>, FocusEnabled<V>> {
@@ -198,10 +208,6 @@ where
         &mut self.focus.focus_listeners
     }
 
-    fn handle(&self) -> &FocusHandle {
-        &self.focus.focus_handle
-    }
-
     fn set_focus_style(&mut self, style: StyleRefinement) {
         self.focus.focus_style = style;
     }
@@ -215,6 +221,12 @@ where
     }
 }
 
+#[derive(Default)]
+pub struct DivState {
+    interactive: InteractiveElementState,
+    focus_handle: Option<FocusHandle>,
+}
+
 impl<V, I, F> Element for Div<V, I, F>
 where
     I: ElementInteraction<V>,
@@ -222,7 +234,7 @@ where
     V: 'static + Send + Sync,
 {
     type ViewState = V;
-    type ElementState = InteractiveElementState;
+    type ElementState = DivState;
 
     fn id(&self) -> Option<ElementId> {
         self.interaction
@@ -236,14 +248,17 @@ where
         element_state: Option<Self::ElementState>,
         cx: &mut ViewContext<Self::ViewState>,
     ) -> Self::ElementState {
-        self.focus.initialize(cx, |focus_handle, cx| {
-            self.interaction
-                .initialize(element_state, focus_handle, cx, |cx| {
+        let mut element_state = element_state.unwrap_or_default();
+        self.focus
+            .initialize(element_state.focus_handle.take(), cx, |focus_handle, cx| {
+                element_state.focus_handle = focus_handle;
+                self.interaction.initialize(cx, |cx| {
                     for child in &mut self.children {
                         child.initialize(view_state, cx);
                     }
                 })
-        })
+            });
+        element_state
     }
 
     fn layout(
@@ -286,7 +301,8 @@ where
                     style.paint(bounds, cx);
 
                     this.focus.paint(bounds, cx);
-                    this.interaction.paint(bounds, element_state, cx);
+                    this.interaction
+                        .paint(bounds, &element_state.interactive, cx);
                 });
 
                 cx.stack(1, |cx| {

crates/gpui3/src/elements/img.rs 🔗

@@ -1,9 +1,8 @@
 use crate::{
-    div, AnyElement, BorrowWindow, Bounds, Div, Element, ElementFocus, ElementId,
-    ElementInteraction, FocusDisabled, FocusEnabled, FocusListeners, Focusable,
-    InteractiveElementState, IntoAnyElement, LayoutId, Pixels, SharedString, StatefulInteraction,
-    StatefulInteractive, StatelessInteraction, StatelessInteractive, StyleRefinement, Styled,
-    ViewContext,
+    div, AnyElement, BorrowWindow, Bounds, Div, DivState, Element, ElementFocus, ElementId,
+    ElementInteraction, FocusDisabled, FocusEnabled, FocusListeners, Focusable, IntoAnyElement,
+    LayoutId, Pixels, SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction,
+    StatelessInteractive, StyleRefinement, Styled, ViewContext,
 };
 use futures::FutureExt;
 use util::ResultExt;
@@ -78,7 +77,7 @@ where
     F: ElementFocus<V>,
 {
     type ViewState = V;
-    type ElementState = InteractiveElementState;
+    type ElementState = DivState;
 
     fn id(&self) -> Option<crate::ElementId> {
         self.base.id()
@@ -192,8 +191,4 @@ where
     fn set_in_focus_style(&mut self, style: StyleRefinement) {
         self.base.set_in_focus_style(style)
     }
-
-    fn handle(&self) -> &crate::FocusHandle {
-        self.base.handle()
-    }
 }

crates/gpui3/src/elements/svg.rs 🔗

@@ -1,8 +1,8 @@
 use crate::{
-    div, AnyElement, Bounds, Div, Element, ElementFocus, ElementId, ElementInteraction,
-    FocusDisabled, FocusEnabled, FocusListeners, Focusable, InteractiveElementState,
-    IntoAnyElement, LayoutId, Pixels, SharedString, StatefulInteraction, StatefulInteractive,
-    StatelessInteraction, StatelessInteractive, StyleRefinement, Styled, ViewContext,
+    div, AnyElement, Bounds, Div, DivState, Element, ElementFocus, ElementId, ElementInteraction,
+    FocusDisabled, FocusEnabled, FocusListeners, Focusable, IntoAnyElement, LayoutId, Pixels,
+    SharedString, StatefulInteraction, StatefulInteractive, StatelessInteraction,
+    StatelessInteractive, StyleRefinement, Styled, ViewContext,
 };
 use util::ResultExt;
 
@@ -68,7 +68,7 @@ where
     F: ElementFocus<V>,
 {
     type ViewState = V;
-    type ElementState = InteractiveElementState;
+    type ElementState = DivState;
 
     fn id(&self) -> Option<crate::ElementId> {
         self.base.id()
@@ -165,8 +165,4 @@ where
     fn set_in_focus_style(&mut self, style: StyleRefinement) {
         self.base.set_in_focus_style(style)
     }
-
-    fn handle(&self) -> &crate::FocusHandle {
-        self.base.handle()
-    }
 }

crates/gpui3/src/focusable.rs 🔗

@@ -9,14 +9,13 @@ use std::sync::Arc;
 pub type FocusListeners<V> = SmallVec<[FocusListener<V>; 2]>;
 
 pub type FocusListener<V> =
-    Arc<dyn Fn(&mut V, &FocusEvent, &mut ViewContext<V>) + Send + Sync + 'static>;
+    Arc<dyn Fn(&mut V, &FocusHandle, &FocusEvent, &mut ViewContext<V>) + Send + Sync + 'static>;
 
 pub trait Focusable: Element {
     fn focus_listeners(&mut self) -> &mut FocusListeners<Self::ViewState>;
     fn set_focus_style(&mut self, style: StyleRefinement);
     fn set_focus_in_style(&mut self, style: StyleRefinement);
     fn set_in_focus_style(&mut self, style: StyleRefinement);
-    fn handle(&self) -> &FocusHandle;
 
     fn focus(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
     where
@@ -52,10 +51,9 @@ pub trait Focusable: Element {
     where
         Self: Sized,
     {
-        let handle = self.handle().clone();
         self.focus_listeners()
-            .push(Arc::new(move |view, event, cx| {
-                if event.focused.as_ref() == Some(&handle) {
+            .push(Arc::new(move |view, focus_handle, event, cx| {
+                if event.focused.as_ref() == Some(focus_handle) {
                     listener(view, event, cx)
                 }
             }));
@@ -72,10 +70,9 @@ pub trait Focusable: Element {
     where
         Self: Sized,
     {
-        let handle = self.handle().clone();
         self.focus_listeners()
-            .push(Arc::new(move |view, event, cx| {
-                if event.blurred.as_ref() == Some(&handle) {
+            .push(Arc::new(move |view, focus_handle, event, cx| {
+                if event.blurred.as_ref() == Some(focus_handle) {
                     listener(view, event, cx)
                 }
             }));
@@ -92,17 +89,16 @@ pub trait Focusable: Element {
     where
         Self: Sized,
     {
-        let handle = self.handle().clone();
         self.focus_listeners()
-            .push(Arc::new(move |view, event, cx| {
+            .push(Arc::new(move |view, focus_handle, event, cx| {
                 let descendant_blurred = event
                     .blurred
                     .as_ref()
-                    .map_or(false, |blurred| handle.contains(blurred, cx));
+                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
                 let descendant_focused = event
                     .focused
                     .as_ref()
-                    .map_or(false, |focused| handle.contains(focused, cx));
+                    .map_or(false, |focused| focus_handle.contains(focused, cx));
 
                 if !descendant_blurred && descendant_focused {
                     listener(view, event, cx)
@@ -121,17 +117,16 @@ pub trait Focusable: Element {
     where
         Self: Sized,
     {
-        let handle = self.handle().clone();
         self.focus_listeners()
-            .push(Arc::new(move |view, event, cx| {
+            .push(Arc::new(move |view, focus_handle, event, cx| {
                 let descendant_blurred = event
                     .blurred
                     .as_ref()
-                    .map_or(false, |blurred| handle.contains(blurred, cx));
+                    .map_or(false, |blurred| focus_handle.contains(blurred, cx));
                 let descendant_focused = event
                     .focused
                     .as_ref()
-                    .map_or(false, |focused| handle.contains(focused, cx));
+                    .map_or(false, |focused| focus_handle.contains(focused, cx));
                 if descendant_blurred && !descendant_focused {
                     listener(view, event, cx)
                 }
@@ -142,17 +137,25 @@ pub trait Focusable: Element {
 
 pub trait ElementFocus<V: 'static + Send + Sync>: 'static + Send + Sync {
     fn as_focusable(&self) -> Option<&FocusEnabled<V>>;
+    fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>>;
 
     fn initialize<R>(
-        &self,
+        &mut self,
+        focus_handle: Option<FocusHandle>,
         cx: &mut ViewContext<V>,
         f: impl FnOnce(Option<FocusHandle>, &mut ViewContext<V>) -> R,
     ) -> R {
-        if let Some(focusable) = self.as_focusable() {
+        if let Some(focusable) = self.as_focusable_mut() {
+            let focus_handle = focusable
+                .focus_handle
+                .get_or_insert_with(|| focus_handle.unwrap_or_else(|| cx.focus_handle()))
+                .clone();
             for listener in focusable.focus_listeners.iter().cloned() {
-                cx.on_focus_changed(move |view, event, cx| listener(view, event, cx));
+                let focus_handle = focus_handle.clone();
+                cx.on_focus_changed(move |view, event, cx| {
+                    listener(view, &focus_handle, event, cx)
+                });
             }
-            let focus_handle = focusable.focus_handle.clone();
             cx.with_focus(focus_handle.clone(), |cx| f(Some(focus_handle), cx))
         } else {
             f(None, cx)
@@ -161,15 +164,19 @@ pub trait ElementFocus<V: 'static + Send + Sync>: 'static + Send + Sync {
 
     fn refine_style(&self, style: &mut Style, cx: &WindowContext) {
         if let Some(focusable) = self.as_focusable() {
-            if focusable.focus_handle.contains_focused(cx) {
+            let focus_handle = focusable
+                .focus_handle
+                .as_ref()
+                .expect("must call initialize before refine_style");
+            if focus_handle.contains_focused(cx) {
                 style.refine(&focusable.focus_in_style);
             }
 
-            if focusable.focus_handle.within_focused(cx) {
+            if focus_handle.within_focused(cx) {
                 style.refine(&focusable.in_focus_style);
             }
 
-            if focusable.focus_handle.is_focused(cx) {
+            if focus_handle.is_focused(cx) {
                 style.refine(&focusable.focus_style);
             }
         }
@@ -177,7 +184,10 @@ pub trait ElementFocus<V: 'static + Send + Sync>: 'static + Send + Sync {
 
     fn paint(&self, bounds: Bounds<Pixels>, cx: &mut WindowContext) {
         if let Some(focusable) = self.as_focusable() {
-            let focus_handle = focusable.focus_handle.clone();
+            let focus_handle = focusable
+                .focus_handle
+                .clone()
+                .expect("must call initialize before paint");
             cx.on_mouse_event(move |event: &MouseDownEvent, phase, cx| {
                 if phase == DispatchPhase::Bubble && bounds.contains_point(&event.position) {
                     if !cx.default_prevented() {
@@ -191,13 +201,38 @@ pub trait ElementFocus<V: 'static + Send + Sync>: 'static + Send + Sync {
 }
 
 pub struct FocusEnabled<V: 'static + Send + Sync> {
-    pub focus_handle: FocusHandle,
+    pub focus_handle: Option<FocusHandle>,
     pub focus_listeners: FocusListeners<V>,
     pub focus_style: StyleRefinement,
     pub focus_in_style: StyleRefinement,
     pub in_focus_style: StyleRefinement,
 }
 
+impl<V> FocusEnabled<V>
+where
+    V: 'static + Send + Sync,
+{
+    pub fn new() -> Self {
+        Self {
+            focus_handle: None,
+            focus_listeners: FocusListeners::default(),
+            focus_style: StyleRefinement::default(),
+            focus_in_style: StyleRefinement::default(),
+            in_focus_style: StyleRefinement::default(),
+        }
+    }
+
+    pub fn tracked(handle: &FocusHandle) -> Self {
+        Self {
+            focus_handle: Some(handle.clone()),
+            focus_listeners: FocusListeners::default(),
+            focus_style: StyleRefinement::default(),
+            focus_in_style: StyleRefinement::default(),
+            in_focus_style: StyleRefinement::default(),
+        }
+    }
+}
+
 impl<V> ElementFocus<V> for FocusEnabled<V>
 where
     V: 'static + Send + Sync,
@@ -205,6 +240,10 @@ where
     fn as_focusable(&self) -> Option<&FocusEnabled<V>> {
         Some(self)
     }
+
+    fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>> {
+        Some(self)
+    }
 }
 
 impl<V> From<FocusHandle> for FocusEnabled<V>
@@ -213,7 +252,7 @@ where
 {
     fn from(value: FocusHandle) -> Self {
         Self {
-            focus_handle: value,
+            focus_handle: Some(value),
             focus_listeners: FocusListeners::default(),
             focus_style: StyleRefinement::default(),
             focus_in_style: StyleRefinement::default(),
@@ -231,4 +270,8 @@ where
     fn as_focusable(&self) -> Option<&FocusEnabled<V>> {
         None
     }
+
+    fn as_focusable_mut(&mut self) -> Option<&mut FocusEnabled<V>> {
+        None
+    }
 }

crates/gpui3/src/interactive.rs 🔗

@@ -305,13 +305,11 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
     fn as_stateful(&self) -> Option<&StatefulInteraction<V>>;
     fn as_stateful_mut(&mut self) -> Option<&mut StatefulInteraction<V>>;
 
-    fn initialize(
+    fn initialize<R>(
         &mut self,
-        element_state: Option<InteractiveElementState>,
-        focus_handle: Option<FocusHandle>,
         cx: &mut ViewContext<V>,
-        f: impl FnOnce(&mut ViewContext<V>),
-    ) -> InteractiveElementState {
+        f: impl FnOnce(&mut ViewContext<V>) -> R,
+    ) -> R {
         if let Some(stateful) = self.as_stateful_mut() {
             cx.with_element_id(stateful.id.clone(), |global_id, cx| {
                 stateful.key_listeners.push((
@@ -329,19 +327,15 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
                         None
                     }),
                 ));
-                let mut element_state = stateful.stateless.initialize(element_state, None, cx, f);
-                element_state.focus_handle = focus_handle
-                    .or(element_state.focus_handle.take())
-                    .or_else(|| cx.focused());
+                let result = stateful.stateless.initialize(cx, f);
                 stateful.key_listeners.pop();
-                element_state
+                result
             })
         } else {
             let stateless = self.as_stateless();
             cx.with_key_dispatch_context(stateless.dispatch_context.clone(), |cx| {
                 cx.with_key_listeners(&stateless.key_listeners, f)
-            });
-            element_state.unwrap_or_default()
+            })
         }
     }
 
@@ -613,7 +607,6 @@ impl ActiveState {
 
 #[derive(Default)]
 pub struct InteractiveElementState {
-    focus_handle: Option<FocusHandle>,
     active_state: Arc<Mutex<ActiveState>>,
     pending_click: Arc<Mutex<Option<MouseDownEvent>>>,
 }

crates/storybook2/src/stories/focus.rs 🔗

@@ -76,12 +76,12 @@ impl FocusStory {
         let color_5 = theme.lowest.variant.default.foreground;
         let color_6 = theme.highest.negative.default.foreground;
 
-        let parent = cx.focus_handle();
         let child_1 = cx.focus_handle();
         let child_2 = cx.focus_handle();
         view(cx.entity(|cx| ()), move |_, cx| {
             div()
-                .focusable(&parent)
+                .id("parent")
+                .focusable()
                 .context("parent")
                 .on_action(|_, action: &ActionA, phase, cx| {
                     println!("Action A dispatched on parent during {:?}", phase);
@@ -105,7 +105,7 @@ impl FocusStory {
                 .focus_in(|style| style.bg(color_3))
                 .child(
                     div()
-                        .focusable(&child_1)
+                        .track_focus(&child_1)
                         .context("child-1")
                         .on_action(|_, action: &ActionB, phase, cx| {
                             println!("Action B dispatched on child 1 during {:?}", phase);
@@ -129,7 +129,7 @@ impl FocusStory {
                 )
                 .child(
                     div()
-                        .focusable(&child_2)
+                        .track_focus(&child_2)
                         .context("child-2")
                         .on_action(|_, action: &ActionC, phase, cx| {
                             println!("Action C dispatched on child 2 during {:?}", phase);