flex.rs

  1use std::{any::Any, cell::Cell, f32::INFINITY, ops::Range, rc::Rc};
  2
  3use crate::{
  4    json::{self, ToJson, Value},
  5    AnyElement, Axis, Element, ElementStateHandle, SizeConstraint, Vector2FExt, ViewContext,
  6};
  7use pathfinder_geometry::{
  8    rect::RectF,
  9    vector::{vec2f, Vector2F},
 10};
 11use serde_json::json;
 12
/// Scroll bookkeeping shared between a scrollable `Flex` and the scroll
/// handler it registers while painting.
#[derive(Default)]
struct ScrollState {
    /// Index of the child that should be brought into view. It is set by
    /// `Flex::scrollable` and consumed (taken) by the next layout pass.
    scroll_to: Cell<Option<usize>>,
    /// Current scroll offset along the flex's main axis. Layout clamps it and
    /// paint subtracts it from the origin of the first child.
    scroll_position: Cell<f32>,
}
 18
/// An element that arranges its children one after another along a single
/// axis, optionally distributing leftover space to flexible children.
pub struct Flex<V> {
    /// Main axis along which children are placed.
    axis: Axis,
    /// Children in layout and paint order.
    children: Vec<AnyElement<V>>,
    /// Present only for scrollable flexes: the persisted scroll state and the
    /// id used when registering the scroll mouse region.
    scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
    /// Cross-axis alignment of children: -1 places them at the start, 0
    /// centers them, and 1 places them at the end.
    child_alignment: f32,
    /// Gap inserted after each child along the main axis.
    spacing: f32,
}
 26
 27impl<V: 'static> Flex<V> {
 28    pub fn new(axis: Axis) -> Self {
 29        Self {
 30            axis,
 31            children: Default::default(),
 32            scroll_state: None,
 33            child_alignment: -1.,
 34            spacing: 0.,
 35        }
 36    }
 37
 38    pub fn row() -> Self {
 39        Self::new(Axis::Horizontal)
 40    }
 41
 42    pub fn column() -> Self {
 43        Self::new(Axis::Vertical)
 44    }
 45
 46    /// Render children centered relative to the cross-axis of the parent flex.
 47    ///
 48    /// If this is a flex row, children will be centered vertically. If this is a
 49    /// flex column, children will be centered horizontally.
 50    pub fn align_children_center(mut self) -> Self {
 51        self.child_alignment = 0.;
 52        self
 53    }
 54
 55    pub fn with_spacing(mut self, spacing: f32) -> Self {
 56        self.spacing = spacing;
 57        self
 58    }
 59
 60    pub fn scrollable<Tag>(
 61        mut self,
 62        element_id: usize,
 63        scroll_to: Option<usize>,
 64        cx: &mut ViewContext<V>,
 65    ) -> Self
 66    where
 67        Tag: 'static,
 68    {
 69        let scroll_state = cx.default_element_state::<Tag, Rc<ScrollState>>(element_id);
 70        scroll_state.read(cx).scroll_to.set(scroll_to);
 71        self.scroll_state = Some((scroll_state, cx.handle().id()));
 72        self
 73    }
 74
 75    pub fn is_empty(&self) -> bool {
 76        self.children.is_empty()
 77    }
 78
 79    fn layout_flex_children(
 80        &mut self,
 81        layout_expanded: bool,
 82        constraint: SizeConstraint,
 83        remaining_space: &mut f32,
 84        remaining_flex: &mut f32,
 85        cross_axis_max: &mut f32,
 86        view: &mut V,
 87        cx: &mut ViewContext<V>,
 88    ) {
 89        let cross_axis = self.axis.invert();
 90        for child in self.children.iter_mut() {
 91            if let Some(metadata) = child.metadata::<FlexParentData>() {
 92                if let Some((flex, expanded)) = metadata.flex {
 93                    if expanded != layout_expanded {
 94                        continue;
 95                    }
 96
 97                    let child_max = if *remaining_flex == 0.0 {
 98                        *remaining_space
 99                    } else {
100                        let space_per_flex = *remaining_space / *remaining_flex;
101                        space_per_flex * flex
102                    };
103                    let child_min = if expanded { child_max } else { 0. };
104                    let child_constraint = match self.axis {
105                        Axis::Horizontal => SizeConstraint::new(
106                            vec2f(child_min, constraint.min.y()),
107                            vec2f(child_max, constraint.max.y()),
108                        ),
109                        Axis::Vertical => SizeConstraint::new(
110                            vec2f(constraint.min.x(), child_min),
111                            vec2f(constraint.max.x(), child_max),
112                        ),
113                    };
114                    let child_size = child.layout(child_constraint, view, cx);
115                    *remaining_space -= child_size.along(self.axis);
116                    *remaining_flex -= flex;
117                    *cross_axis_max = cross_axis_max.max(child_size.along(cross_axis));
118                }
119            }
120        }
121    }
122}
123
124impl<V> Extend<AnyElement<V>> for Flex<V> {
125    fn extend<T: IntoIterator<Item = AnyElement<V>>>(&mut self, children: T) {
126        self.children.extend(children);
127    }
128}
129
130impl<V: 'static> Element<V> for Flex<V> {
131    type LayoutState = f32;
132    type PaintState = ();
133
134    fn layout(
135        &mut self,
136        constraint: SizeConstraint,
137        view: &mut V,
138        cx: &mut ViewContext<V>,
139    ) -> (Vector2F, Self::LayoutState) {
140        let mut total_flex = None;
141        let mut fixed_space = self.children.len().saturating_sub(1) as f32 * self.spacing;
142        let mut contains_float = false;
143
144        let cross_axis = self.axis.invert();
145        let mut cross_axis_max: f32 = 0.0;
146        for child in self.children.iter_mut() {
147            let metadata = child.metadata::<FlexParentData>();
148            contains_float |= metadata.map_or(false, |metadata| metadata.float);
149
150            if let Some(flex) = metadata.and_then(|metadata| metadata.flex.map(|(flex, _)| flex)) {
151                *total_flex.get_or_insert(0.) += flex;
152            } else {
153                let child_constraint = match self.axis {
154                    Axis::Horizontal => SizeConstraint::new(
155                        vec2f(0.0, constraint.min.y()),
156                        vec2f(INFINITY, constraint.max.y()),
157                    ),
158                    Axis::Vertical => SizeConstraint::new(
159                        vec2f(constraint.min.x(), 0.0),
160                        vec2f(constraint.max.x(), INFINITY),
161                    ),
162                };
163                let size = child.layout(child_constraint, view, cx);
164                fixed_space += size.along(self.axis);
165                cross_axis_max = cross_axis_max.max(size.along(cross_axis));
166            }
167        }
168
169        let mut remaining_space = constraint.max_along(self.axis) - fixed_space;
170        let mut size = if let Some(mut remaining_flex) = total_flex {
171            if remaining_space.is_infinite() {
172                panic!("flex contains flexible children but has an infinite constraint along the flex axis");
173            }
174
175            self.layout_flex_children(
176                false,
177                constraint,
178                &mut remaining_space,
179                &mut remaining_flex,
180                &mut cross_axis_max,
181                view,
182                cx,
183            );
184            self.layout_flex_children(
185                true,
186                constraint,
187                &mut remaining_space,
188                &mut remaining_flex,
189                &mut cross_axis_max,
190                view,
191                cx,
192            );
193
194            match self.axis {
195                Axis::Horizontal => vec2f(constraint.max.x() - remaining_space, cross_axis_max),
196                Axis::Vertical => vec2f(cross_axis_max, constraint.max.y() - remaining_space),
197            }
198        } else {
199            match self.axis {
200                Axis::Horizontal => vec2f(fixed_space, cross_axis_max),
201                Axis::Vertical => vec2f(cross_axis_max, fixed_space),
202            }
203        };
204
205        if contains_float {
206            match self.axis {
207                Axis::Horizontal => size.set_x(size.x().max(constraint.max.x())),
208                Axis::Vertical => size.set_y(size.y().max(constraint.max.y())),
209            }
210        }
211
212        if constraint.min.x().is_finite() {
213            size.set_x(size.x().max(constraint.min.x()));
214        }
215        if constraint.min.y().is_finite() {
216            size.set_y(size.y().max(constraint.min.y()));
217        }
218
219        if size.x() > constraint.max.x() {
220            size.set_x(constraint.max.x());
221        }
222        if size.y() > constraint.max.y() {
223            size.set_y(constraint.max.y());
224        }
225
226        if let Some(scroll_state) = self.scroll_state.as_ref() {
227            scroll_state.0.update(cx, |scroll_state, _| {
228                if let Some(scroll_to) = scroll_state.scroll_to.take() {
229                    let visible_start = scroll_state.scroll_position.get();
230                    let visible_end = visible_start + size.along(self.axis);
231                    if let Some(child) = self.children.get(scroll_to) {
232                        let child_start: f32 = self.children[..scroll_to]
233                            .iter()
234                            .map(|c| c.size().along(self.axis))
235                            .sum();
236                        let child_end = child_start + child.size().along(self.axis);
237                        if child_start < visible_start {
238                            scroll_state.scroll_position.set(child_start);
239                        } else if child_end > visible_end {
240                            scroll_state
241                                .scroll_position
242                                .set(child_end - size.along(self.axis));
243                        }
244                    }
245                }
246
247                scroll_state.scroll_position.set(
248                    scroll_state
249                        .scroll_position
250                        .get()
251                        .min(-remaining_space)
252                        .max(0.),
253                );
254            });
255        }
256
257        (size, remaining_space)
258    }
259
260    fn paint(
261        &mut self,
262        bounds: RectF,
263        visible_bounds: RectF,
264        remaining_space: &mut Self::LayoutState,
265        view: &mut V,
266        cx: &mut ViewContext<V>,
267    ) -> Self::PaintState {
268        let visible_bounds = bounds.intersection(visible_bounds).unwrap_or_default();
269
270        let mut remaining_space = *remaining_space;
271        let overflowing = remaining_space < 0.;
272        if overflowing {
273            cx.scene().push_layer(Some(visible_bounds));
274        }
275
276        if let Some((scroll_state, id)) = &self.scroll_state {
277            let scroll_state = scroll_state.read(cx).clone();
278            cx.scene().push_mouse_region(
279                crate::MouseRegion::new::<Self>(*id, 0, bounds)
280                    .on_scroll({
281                        let axis = self.axis;
282                        move |e, _: &mut V, cx| {
283                            if remaining_space < 0. {
284                                let scroll_delta = e.delta.raw();
285
286                                let mut delta = match axis {
287                                    Axis::Horizontal => {
288                                        if scroll_delta.x().abs() >= scroll_delta.y().abs() {
289                                            scroll_delta.x()
290                                        } else {
291                                            scroll_delta.y()
292                                        }
293                                    }
294                                    Axis::Vertical => scroll_delta.y(),
295                                };
296                                if !e.delta.precise() {
297                                    delta *= 20.;
298                                }
299
300                                scroll_state
301                                    .scroll_position
302                                    .set(scroll_state.scroll_position.get() - delta);
303
304                                cx.notify();
305                            } else {
306                                cx.propagate_event();
307                            }
308                        }
309                    })
310                    .on_move(|_, _: &mut V, _| { /* Capture move events */ }),
311            )
312        }
313
314        let mut child_origin = bounds.origin();
315        if let Some(scroll_state) = self.scroll_state.as_ref() {
316            let scroll_position = scroll_state.0.read(cx).scroll_position.get();
317            match self.axis {
318                Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
319                Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),
320            }
321        }
322
323        for child in self.children.iter_mut() {
324            if remaining_space > 0. {
325                if let Some(metadata) = child.metadata::<FlexParentData>() {
326                    if metadata.float {
327                        match self.axis {
328                            Axis::Horizontal => child_origin += vec2f(remaining_space, 0.0),
329                            Axis::Vertical => child_origin += vec2f(0.0, remaining_space),
330                        }
331                        remaining_space = 0.;
332                    }
333                }
334            }
335
336            // We use the child_alignment f32 to determine a point along the cross axis of the
337            // overall flex element and each child. We then align these points. So 0 would center
338            // each child relative to the overall height/width of the flex. -1 puts children at
339            // the start. 1 puts children at the end.
340            let aligned_child_origin = {
341                let cross_axis = self.axis.invert();
342                let my_center = bounds.size().along(cross_axis) / 2.;
343                let my_target = my_center + my_center * self.child_alignment;
344
345                let child_center = child.size().along(cross_axis) / 2.;
346                let child_target = child_center + child_center * self.child_alignment;
347
348                let mut aligned_child_origin = child_origin;
349                match self.axis {
350                    Axis::Horizontal => aligned_child_origin
351                        .set_y(aligned_child_origin.y() - (child_target - my_target)),
352                    Axis::Vertical => aligned_child_origin
353                        .set_x(aligned_child_origin.x() - (child_target - my_target)),
354                }
355
356                aligned_child_origin
357            };
358
359            child.paint(aligned_child_origin, visible_bounds, view, cx);
360
361            match self.axis {
362                Axis::Horizontal => child_origin += vec2f(child.size().x() + self.spacing, 0.0),
363                Axis::Vertical => child_origin += vec2f(0.0, child.size().y() + self.spacing),
364            }
365        }
366
367        if overflowing {
368            cx.scene().pop_layer();
369        }
370    }
371
372    fn rect_for_text_range(
373        &self,
374        range_utf16: Range<usize>,
375        _: RectF,
376        _: RectF,
377        _: &Self::LayoutState,
378        _: &Self::PaintState,
379        view: &V,
380        cx: &ViewContext<V>,
381    ) -> Option<RectF> {
382        self.children
383            .iter()
384            .find_map(|child| child.rect_for_text_range(range_utf16.clone(), view, cx))
385    }
386
387    fn debug(
388        &self,
389        bounds: RectF,
390        _: &Self::LayoutState,
391        _: &Self::PaintState,
392        view: &V,
393        cx: &ViewContext<V>,
394    ) -> json::Value {
395        json!({
396            "type": "Flex",
397            "bounds": bounds.to_json(),
398            "axis": self.axis.to_json(),
399            "children": self.children.iter().map(|child| child.debug(view, cx)).collect::<Vec<json::Value>>()
400        })
401    }
402}
403
/// Per-child layout hints that a `Flex` reads through `Element::metadata`.
struct FlexParentData {
    /// `Some((factor, expanded))` for flexible children: `factor` determines
    /// the child's share of the remaining space, and `expanded` forces the
    /// child to fill that share exactly. `None` for inflexible children.
    flex: Option<(f32, bool)>,
    /// When `true`, leftover main-axis space is inserted before this child
    /// while painting.
    float: bool,
}
408
/// Wraps a child element and attaches the layout hints that an enclosing
/// `Flex` consults when laying out and painting it.
pub struct FlexItem<V> {
    /// Hints exposed to the parent via `Element::metadata`.
    metadata: FlexParentData,
    /// The wrapped element; layout and paint are forwarded to it.
    child: AnyElement<V>,
}
413
414impl<V: 'static> FlexItem<V> {
415    pub fn new(child: impl Element<V>) -> Self {
416        FlexItem {
417            metadata: FlexParentData {
418                flex: None,
419                float: false,
420            },
421            child: child.into_any(),
422        }
423    }
424
425    pub fn flex(mut self, flex: f32, expanded: bool) -> Self {
426        self.metadata.flex = Some((flex, expanded));
427        self
428    }
429
430    pub fn float(mut self) -> Self {
431        self.metadata.float = true;
432        self
433    }
434}
435
436impl<V: 'static> Element<V> for FlexItem<V> {
437    type LayoutState = ();
438    type PaintState = ();
439
440    fn layout(
441        &mut self,
442        constraint: SizeConstraint,
443        view: &mut V,
444        cx: &mut ViewContext<V>,
445    ) -> (Vector2F, Self::LayoutState) {
446        let size = self.child.layout(constraint, view, cx);
447        (size, ())
448    }
449
450    fn paint(
451        &mut self,
452        bounds: RectF,
453        visible_bounds: RectF,
454        _: &mut Self::LayoutState,
455        view: &mut V,
456        cx: &mut ViewContext<V>,
457    ) -> Self::PaintState {
458        self.child.paint(bounds.origin(), visible_bounds, view, cx)
459    }
460
461    fn rect_for_text_range(
462        &self,
463        range_utf16: Range<usize>,
464        _: RectF,
465        _: RectF,
466        _: &Self::LayoutState,
467        _: &Self::PaintState,
468        view: &V,
469        cx: &ViewContext<V>,
470    ) -> Option<RectF> {
471        self.child.rect_for_text_range(range_utf16, view, cx)
472    }
473
474    fn metadata(&self) -> Option<&dyn Any> {
475        Some(&self.metadata)
476    }
477
478    fn debug(
479        &self,
480        _: RectF,
481        _: &Self::LayoutState,
482        _: &Self::PaintState,
483        view: &V,
484        cx: &ViewContext<V>,
485    ) -> Value {
486        json!({
487            "type": "Flexible",
488            "flex": self.metadata.flex,
489            "child": self.child.debug(view, cx)
490        })
491    }
492}