hoverable.rs

 1use crate::{
 2    element::{Element, Layout},
 3    layout_context::LayoutContext,
 4    paint_context::PaintContext,
 5    style::{StyleRefinement, Styleable},
 6};
 7use anyhow::Result;
 8use gpui::platform::MouseMovedEvent;
 9use refineable::Refineable;
10use std::{cell::Cell, marker::PhantomData};
11
12pub struct Hoverable<V: 'static, E: Element<V> + Styleable> {
13    hovered: Cell<bool>,
14    child_style: StyleRefinement,
15    hovered_style: StyleRefinement,
16    child: E,
17    view_type: PhantomData<V>,
18}
19
20pub fn hoverable<V, E: Element<V> + Styleable>(mut child: E) -> Hoverable<V, E> {
21    Hoverable {
22        hovered: Cell::new(false),
23        child_style: child.declared_style().clone(),
24        hovered_style: Default::default(),
25        child,
26        view_type: PhantomData,
27    }
28}
29
30impl<V, E: Element<V> + Styleable> Styleable for Hoverable<V, E> {
31    type Style = E::Style;
32
33    fn declared_style(&mut self) -> &mut crate::style::StyleRefinement {
34        self.child.declared_style()
35    }
36}
37
38impl<V: 'static, E: Element<V> + Styleable> Element<V> for Hoverable<V, E> {
39    type Layout = E::Layout;
40
41    fn layout(&mut self, view: &mut V, cx: &mut LayoutContext<V>) -> Result<Layout<V, Self::Layout>>
42    where
43        Self: Sized,
44    {
45        self.child.layout(view, cx)
46    }
47
48    fn paint(
49        &mut self,
50        view: &mut V,
51        layout: &mut Layout<V, Self::Layout>,
52        cx: &mut PaintContext<V>,
53    ) where
54        Self: Sized,
55    {
56        if self.hovered.get() {
57            // If hovered, refine the child's style with this element's style.
58            self.child.declared_style().refine(&self.hovered_style);
59        } else {
60            // Otherwise, set the child's style back to its original style.
61            *self.child.declared_style() = self.child_style.clone();
62        }
63
64        let bounds = layout.bounds(cx);
65        let order = layout.order(cx);
66        self.hovered.set(bounds.contains_point(cx.mouse_position()));
67        let was_hovered = self.hovered.clone();
68        cx.on_event(order, move |view, event: &MouseMovedEvent, cx| {
69            let is_hovered = bounds.contains_point(event.position);
70            if is_hovered != was_hovered.get() {
71                was_hovered.set(is_hovered);
72                cx.repaint();
73            }
74        });
75    }
76}