Clear active state when drag starts

Nathan Sobo created

Change summary

crates/gpui2/src/interactive.rs  | 98 ++++++++++++++++++++++++++-------
crates/ui2/src/components/tab.rs |  3 
2 files changed, 79 insertions(+), 22 deletions(-)

Detailed changes

crates/gpui2/src/interactive.rs 🔗

@@ -336,6 +336,34 @@ pub trait StatefulInteractive: StatelessInteractive {
             }));
         self
     }
+
+    fn drag_over<S: 'static>(mut self, f: impl FnOnce(StyleRefinement) -> StyleRefinement) -> Self
+    where
+        Self: Sized,
+    {
+        self.stateful_interaction()
+            .drag_over_styles
+            .push((TypeId::of::<S>(), f(StyleRefinement::default())));
+        self
+    }
+
+    fn group_drag_over<S: 'static>(
+        mut self,
+        group_name: impl Into<SharedString>,
+        f: impl FnOnce(StyleRefinement) -> StyleRefinement,
+    ) -> Self
+    where
+        Self: Sized,
+    {
+        self.stateful_interaction().group_drag_over_styles.push((
+            TypeId::of::<S>(),
+            GroupStyle {
+                group: group_name.into(),
+                style: f(StyleRefinement::default()),
+            },
+        ));
+        self
+    }
 }
 
 pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
@@ -398,6 +426,26 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
             style.refine(&stateless.hover_style);
         }
 
+        if let Some(drag) = cx.active_drag.take() {
+            for (state_type, group_drag_style) in &self.as_stateless().group_drag_over_styles {
+                if let Some(group_bounds) = GroupBounds::get(&group_drag_style.group, cx) {
+                    if *state_type == drag.state_type
+                        && group_bounds.contains_point(&mouse_position)
+                    {
+                        style.refine(&group_drag_style.style);
+                    }
+                }
+            }
+
+            for (state_type, drag_over_style) in &self.as_stateless().drag_over_styles {
+                if *state_type == drag.state_type && bounds.contains_point(&mouse_position) {
+                    style.refine(drag_over_style);
+                }
+            }
+
+            cx.active_drag = Some(drag);
+        }
+
         if let Some(stateful) = self.as_stateful() {
             let active_state = element_state.active_state.lock();
             if active_state.group {
@@ -450,11 +498,27 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
             .and_then(|group_hover| GroupBounds::get(&group_hover.group, cx));
 
         if let Some(group_bounds) = hover_group_bounds {
-            paint_hover_listener(group_bounds, cx);
+            let hovered = group_bounds.contains_point(&cx.mouse_position());
+            cx.on_mouse_event(move |_, event: &MouseMoveEvent, phase, cx| {
+                if phase == DispatchPhase::Capture {
+                    if group_bounds.contains_point(&event.position) != hovered {
+                        cx.notify();
+                    }
+                }
+            });
         }
 
-        if stateless.hover_style.is_some() {
-            paint_hover_listener(bounds, cx);
+        if stateless.hover_style.is_some()
+            || (cx.active_drag.is_some() && !stateless.drag_over_styles.is_empty())
+        {
+            let hovered = bounds.contains_point(&cx.mouse_position());
+            cx.on_mouse_event(move |_, event: &MouseMoveEvent, phase, cx| {
+                if phase == DispatchPhase::Capture {
+                    if bounds.contains_point(&event.position) != hovered {
+                        cx.notify();
+                    }
+                }
+            });
         }
 
         if let Some(stateful) = self.as_stateful() {
@@ -466,6 +530,7 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
                 let mouse_down = pending_mouse_down.lock().clone();
                 if let Some(mouse_down) = mouse_down {
                     if let Some(drag_listener) = drag_listener {
+                        let active_state = element_state.active_state.clone();
                         cx.on_mouse_event(move |view_state, event: &MouseMoveEvent, phase, cx| {
                             if cx.active_drag.is_some() {
                                 if phase == DispatchPhase::Capture {
@@ -476,6 +541,7 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
                             {
                                 let cursor_offset = event.position - bounds.origin;
                                 let any_drag = drag_listener(view_state, cursor_offset, cx);
+                                *active_state.lock() = ActiveState::default();
                                 cx.start_drag(any_drag);
                                 cx.stop_propagation();
                             }
@@ -570,30 +636,16 @@ pub trait ElementInteraction<V: 'static + Send + Sync>: 'static + Send + Sync {
     }
 }
 
-fn paint_hover_listener<V>(bounds: Bounds<Pixels>, cx: &mut ViewContext<V>)
-where
-    V: 'static + Send + Sync,
-{
-    let hovered = bounds.contains_point(&cx.mouse_position());
-    cx.on_mouse_event(move |_, event: &MouseMoveEvent, phase, cx| {
-        if phase == DispatchPhase::Capture {
-            if bounds.contains_point(&event.position) != hovered {
-                cx.notify();
-            }
-        }
-    });
-}
-
 #[derive(Deref, DerefMut)]
 pub struct StatefulInteraction<V: 'static + Send + Sync> {
     pub id: ElementId,
     #[deref]
     #[deref_mut]
     stateless: StatelessInteraction<V>,
-    pub click_listeners: SmallVec<[ClickListener<V>; 2]>,
-    pub(crate) drag_listener: Option<DragListener<V>>,
-    pub active_style: StyleRefinement,
-    pub group_active_style: Option<GroupStyle>,
+    click_listeners: SmallVec<[ClickListener<V>; 2]>,
+    active_style: StyleRefinement,
+    group_active_style: Option<GroupStyle>,
+    drag_listener: Option<DragListener<V>>,
 }
 
 impl<V> ElementInteraction<V> for StatefulInteraction<V>
@@ -642,6 +694,8 @@ pub struct StatelessInteraction<V> {
     pub key_listeners: SmallVec<[(TypeId, KeyListener<V>); 32]>,
     pub hover_style: StyleRefinement,
     pub group_hover_style: Option<GroupStyle>,
+    drag_over_styles: SmallVec<[(TypeId, StyleRefinement); 2]>,
+    group_drag_over_styles: SmallVec<[(TypeId, GroupStyle); 2]>,
 }
 
 impl<V> StatelessInteraction<V>
@@ -732,6 +786,8 @@ impl<V> Default for StatelessInteraction<V> {
             key_listeners: SmallVec::new(),
             hover_style: StyleRefinement::default(),
             group_hover_style: None,
+            drag_over_styles: SmallVec::new(),
+            group_drag_over_styles: SmallVec::new(),
         }
     }
 }

crates/ui2/src/components/tab.rs 🔗

@@ -125,6 +125,7 @@ impl<S: 'static + Send + Sync + Clone> Tab<S> {
             .on_drag(move |_view, _cx| {
                 Drag::new(drag_state.clone(), |view, cx| div().w_8().h_4().bg(red()))
             })
+            .drag_over::<TabDragState>(|d| d.bg(black()))
             .px_2()
             .py_0p5()
             .flex()
@@ -160,7 +161,7 @@ impl<S: 'static + Send + Sync + Clone> Tab<S> {
     }
 }
 
-use gpui2::{red, Drag, ElementId};
+use gpui2::{black, red, Drag, ElementId};
 #[cfg(feature = "stories")]
 pub use stories::*;