// flex.rs

  1use std::{any::Any, cell::Cell, f32::INFINITY, ops::Range, rc::Rc};
  2
  3use crate::{
  4    json::{self, ToJson, Value},
  5    AnyElement, Axis, Element, ElementStateHandle, SceneBuilder, SizeConstraint, Vector2FExt, View,
  6    ViewContext,
  7};
  8use pathfinder_geometry::{
  9    rect::RectF,
 10    vector::{vec2f, Vector2F},
 11};
 12use serde_json::json;
 13
 14#[derive(Default)]
 15struct ScrollState {
 16    scroll_to: Cell<Option<usize>>,
 17    scroll_position: Cell<f32>,
 18}
 19
 20pub struct Flex<V: View> {
 21    axis: Axis,
 22    children: Vec<AnyElement<V>>,
 23    scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
 24    child_alignment: f32,
 25}
 26
 27impl<V: View> Flex<V> {
 28    pub fn new(axis: Axis) -> Self {
 29        Self {
 30            axis,
 31            children: Default::default(),
 32            scroll_state: None,
 33            child_alignment: -1.,
 34        }
 35    }
 36
 37    pub fn row() -> Self {
 38        Self::new(Axis::Horizontal)
 39    }
 40
 41    pub fn column() -> Self {
 42        Self::new(Axis::Vertical)
 43    }
 44
 45    /// Render children centered relative to the cross-axis of the parent flex.
 46    ///
 47    /// If this is a flex row, children will be centered vertically. If this is a
 48    /// flex column, children will be centered horizontally.
 49    pub fn align_children_center(mut self) -> Self {
 50        self.child_alignment = 0.;
 51        self
 52    }
 53
 54    pub fn scrollable<Tag>(
 55        mut self,
 56        element_id: usize,
 57        scroll_to: Option<usize>,
 58        cx: &mut ViewContext<V>,
 59    ) -> Self
 60    where
 61        Tag: 'static,
 62    {
 63        let scroll_state = cx.default_element_state::<Tag, Rc<ScrollState>>(element_id);
 64        scroll_state.read(cx).scroll_to.set(scroll_to);
 65        self.scroll_state = Some((scroll_state, cx.handle().id()));
 66        self
 67    }
 68
 69    fn layout_flex_children(
 70        &mut self,
 71        layout_expanded: bool,
 72        constraint: SizeConstraint,
 73        remaining_space: &mut f32,
 74        remaining_flex: &mut f32,
 75        cross_axis_max: &mut f32,
 76        view: &mut V,
 77        cx: &mut ViewContext<V>,
 78    ) {
 79        let cross_axis = self.axis.invert();
 80        for child in &mut self.children {
 81            if let Some(metadata) = child.metadata::<FlexParentData>() {
 82                if let Some((flex, expanded)) = metadata.flex {
 83                    if expanded != layout_expanded {
 84                        continue;
 85                    }
 86
 87                    let child_max = if *remaining_flex == 0.0 {
 88                        *remaining_space
 89                    } else {
 90                        let space_per_flex = *remaining_space / *remaining_flex;
 91                        space_per_flex * flex
 92                    };
 93                    let child_min = if expanded { child_max } else { 0. };
 94                    let child_constraint = match self.axis {
 95                        Axis::Horizontal => SizeConstraint::new(
 96                            vec2f(child_min, constraint.min.y()),
 97                            vec2f(child_max, constraint.max.y()),
 98                        ),
 99                        Axis::Vertical => SizeConstraint::new(
100                            vec2f(constraint.min.x(), child_min),
101                            vec2f(constraint.max.x(), child_max),
102                        ),
103                    };
104                    let child_size = child.layout(child_constraint, view, cx);
105                    *remaining_space -= child_size.along(self.axis);
106                    *remaining_flex -= flex;
107                    *cross_axis_max = cross_axis_max.max(child_size.along(cross_axis));
108                }
109            }
110        }
111    }
112}
113
114impl<V: View> Extend<AnyElement<V>> for Flex<V> {
115    fn extend<T: IntoIterator<Item = AnyElement<V>>>(&mut self, children: T) {
116        self.children.extend(children);
117    }
118}
119
120impl<V: View> Element<V> for Flex<V> {
121    type LayoutState = f32;
122    type PaintState = ();
123
124    fn layout(
125        &mut self,
126        constraint: SizeConstraint,
127        view: &mut V,
128        cx: &mut ViewContext<V>,
129    ) -> (Vector2F, Self::LayoutState) {
130        let mut total_flex = None;
131        let mut fixed_space = 0.0;
132        let mut contains_float = false;
133
134        let cross_axis = self.axis.invert();
135        let mut cross_axis_max: f32 = 0.0;
136        for child in &mut self.children {
137            let metadata = child.metadata::<FlexParentData>();
138            contains_float |= metadata.map_or(false, |metadata| metadata.float);
139
140            if let Some(flex) = metadata.and_then(|metadata| metadata.flex.map(|(flex, _)| flex)) {
141                *total_flex.get_or_insert(0.) += flex;
142            } else {
143                let child_constraint = match self.axis {
144                    Axis::Horizontal => SizeConstraint::new(
145                        vec2f(0.0, constraint.min.y()),
146                        vec2f(INFINITY, constraint.max.y()),
147                    ),
148                    Axis::Vertical => SizeConstraint::new(
149                        vec2f(constraint.min.x(), 0.0),
150                        vec2f(constraint.max.x(), INFINITY),
151                    ),
152                };
153                let size = child.layout(child_constraint, view, cx);
154                fixed_space += size.along(self.axis);
155                cross_axis_max = cross_axis_max.max(size.along(cross_axis));
156            }
157        }
158
159        let mut remaining_space = constraint.max_along(self.axis) - fixed_space;
160        let mut size = if let Some(mut remaining_flex) = total_flex {
161            if remaining_space.is_infinite() {
162                panic!("flex contains flexible children but has an infinite constraint along the flex axis");
163            }
164
165            self.layout_flex_children(
166                false,
167                constraint,
168                &mut remaining_space,
169                &mut remaining_flex,
170                &mut cross_axis_max,
171                view,
172                cx,
173            );
174            self.layout_flex_children(
175                true,
176                constraint,
177                &mut remaining_space,
178                &mut remaining_flex,
179                &mut cross_axis_max,
180                view,
181                cx,
182            );
183
184            match self.axis {
185                Axis::Horizontal => vec2f(constraint.max.x() - remaining_space, cross_axis_max),
186                Axis::Vertical => vec2f(cross_axis_max, constraint.max.y() - remaining_space),
187            }
188        } else {
189            match self.axis {
190                Axis::Horizontal => vec2f(fixed_space, cross_axis_max),
191                Axis::Vertical => vec2f(cross_axis_max, fixed_space),
192            }
193        };
194
195        if contains_float {
196            match self.axis {
197                Axis::Horizontal => size.set_x(size.x().max(constraint.max.x())),
198                Axis::Vertical => size.set_y(size.y().max(constraint.max.y())),
199            }
200        }
201
202        if constraint.min.x().is_finite() {
203            size.set_x(size.x().max(constraint.min.x()));
204        }
205        if constraint.min.y().is_finite() {
206            size.set_y(size.y().max(constraint.min.y()));
207        }
208
209        if size.x() > constraint.max.x() {
210            size.set_x(constraint.max.x());
211        }
212        if size.y() > constraint.max.y() {
213            size.set_y(constraint.max.y());
214        }
215
216        if let Some(scroll_state) = self.scroll_state.as_ref() {
217            scroll_state.0.update(cx, |scroll_state, _| {
218                if let Some(scroll_to) = scroll_state.scroll_to.take() {
219                    let visible_start = scroll_state.scroll_position.get();
220                    let visible_end = visible_start + size.along(self.axis);
221                    if let Some(child) = self.children.get(scroll_to) {
222                        let child_start: f32 = self.children[..scroll_to]
223                            .iter()
224                            .map(|c| c.size().along(self.axis))
225                            .sum();
226                        let child_end = child_start + child.size().along(self.axis);
227                        if child_start < visible_start {
228                            scroll_state.scroll_position.set(child_start);
229                        } else if child_end > visible_end {
230                            scroll_state
231                                .scroll_position
232                                .set(child_end - size.along(self.axis));
233                        }
234                    }
235                }
236
237                scroll_state.scroll_position.set(
238                    scroll_state
239                        .scroll_position
240                        .get()
241                        .min(-remaining_space)
242                        .max(0.),
243                );
244            });
245        }
246
247        (size, remaining_space)
248    }
249
250    fn paint(
251        &mut self,
252        scene: &mut SceneBuilder,
253        bounds: RectF,
254        visible_bounds: RectF,
255        remaining_space: &mut Self::LayoutState,
256        view: &mut V,
257        cx: &mut ViewContext<V>,
258    ) -> Self::PaintState {
259        let visible_bounds = bounds.intersection(visible_bounds).unwrap_or_default();
260
261        let mut remaining_space = *remaining_space;
262        let overflowing = remaining_space < 0.;
263        if overflowing {
264            scene.push_layer(Some(visible_bounds));
265        }
266
267        if let Some(scroll_state) = &self.scroll_state {
268            scene.push_mouse_region(
269                crate::MouseRegion::new::<Self>(scroll_state.1, 0, bounds)
270                    .on_scroll({
271                        let scroll_state = scroll_state.0.read(cx).clone();
272                        let axis = self.axis;
273                        move |e, _: &mut V, cx| {
274                            if remaining_space < 0. {
275                                let scroll_delta = e.delta.raw();
276
277                                let mut delta = match axis {
278                                    Axis::Horizontal => {
279                                        if scroll_delta.x().abs() >= scroll_delta.y().abs() {
280                                            scroll_delta.x()
281                                        } else {
282                                            scroll_delta.y()
283                                        }
284                                    }
285                                    Axis::Vertical => scroll_delta.y(),
286                                };
287                                if !e.delta.precise() {
288                                    delta *= 20.;
289                                }
290
291                                scroll_state
292                                    .scroll_position
293                                    .set(scroll_state.scroll_position.get() - delta);
294
295                                cx.notify();
296                            } else {
297                                cx.propagate_event();
298                            }
299                        }
300                    })
301                    .on_move(|_, _: &mut V, _| { /* Capture move events */ }),
302            )
303        }
304
305        let mut child_origin = bounds.origin();
306        if let Some(scroll_state) = self.scroll_state.as_ref() {
307            let scroll_position = scroll_state.0.read(cx).scroll_position.get();
308            match self.axis {
309                Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
310                Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),
311            }
312        }
313
314        for child in &mut self.children {
315            if remaining_space > 0. {
316                if let Some(metadata) = child.metadata::<FlexParentData>() {
317                    if metadata.float {
318                        match self.axis {
319                            Axis::Horizontal => child_origin += vec2f(remaining_space, 0.0),
320                            Axis::Vertical => child_origin += vec2f(0.0, remaining_space),
321                        }
322                        remaining_space = 0.;
323                    }
324                }
325            }
326
327            // We use the child_alignment f32 to determine a point along the cross axis of the
328            // overall flex element and each child. We then align these points. So 0 would center
329            // each child relative to the overall height/width of the flex. -1 puts children at
330            // the start. 1 puts children at the end.
331            let aligned_child_origin = {
332                let cross_axis = self.axis.invert();
333                let my_center = bounds.size().along(cross_axis) / 2.;
334                let my_target = my_center + my_center * self.child_alignment;
335
336                let child_center = child.size().along(cross_axis) / 2.;
337                let child_target = child_center + child_center * self.child_alignment;
338
339                let mut aligned_child_origin = child_origin;
340                match self.axis {
341                    Axis::Horizontal => aligned_child_origin
342                        .set_y(aligned_child_origin.y() - (child_target - my_target)),
343                    Axis::Vertical => aligned_child_origin
344                        .set_x(aligned_child_origin.x() - (child_target - my_target)),
345                }
346
347                aligned_child_origin
348            };
349
350            child.paint(scene, aligned_child_origin, visible_bounds, view, cx);
351
352            match self.axis {
353                Axis::Horizontal => child_origin += vec2f(child.size().x(), 0.0),
354                Axis::Vertical => child_origin += vec2f(0.0, child.size().y()),
355            }
356        }
357
358        if overflowing {
359            scene.pop_layer();
360        }
361    }
362
363    fn rect_for_text_range(
364        &self,
365        range_utf16: Range<usize>,
366        _: RectF,
367        _: RectF,
368        _: &Self::LayoutState,
369        _: &Self::PaintState,
370        view: &V,
371        cx: &ViewContext<V>,
372    ) -> Option<RectF> {
373        self.children
374            .iter()
375            .find_map(|child| child.rect_for_text_range(range_utf16.clone(), view, cx))
376    }
377
378    fn debug(
379        &self,
380        bounds: RectF,
381        _: &Self::LayoutState,
382        _: &Self::PaintState,
383        view: &V,
384        cx: &ViewContext<V>,
385    ) -> json::Value {
386        json!({
387            "type": "Flex",
388            "bounds": bounds.to_json(),
389            "axis": self.axis.to_json(),
390            "children": self.children.iter().map(|child| child.debug(view, cx)).collect::<Vec<json::Value>>()
391        })
392    }
393}
394
/// Per-child layout parameters that a `Flex` reads from its children's metadata.
struct FlexParentData {
    /// When set, leftover main-axis space is inserted before this child, pushing it and
    /// every later sibling toward the far end (only the first such child receives it).
    float: bool,
    /// `Some((flex, expanded))` for a flexible child. `flex` is its weight when the
    /// leftover main-axis space is divided, and `expanded` forces it to fill its whole
    /// share. `None` means the child is laid out at its natural size.
    flex: Option<(f32, bool)>,
}
399
400pub struct FlexItem<V: View> {
401    metadata: FlexParentData,
402    child: AnyElement<V>,
403}
404
405impl<V: View> FlexItem<V> {
406    pub fn new(child: impl Element<V>) -> Self {
407        FlexItem {
408            metadata: FlexParentData {
409                flex: None,
410                float: false,
411            },
412            child: child.into_any(),
413        }
414    }
415
416    pub fn flex(mut self, flex: f32, expanded: bool) -> Self {
417        self.metadata.flex = Some((flex, expanded));
418        self
419    }
420
421    pub fn float(mut self) -> Self {
422        self.metadata.float = true;
423        self
424    }
425}
426
427impl<V: View> Element<V> for FlexItem<V> {
428    type LayoutState = ();
429    type PaintState = ();
430
431    fn layout(
432        &mut self,
433        constraint: SizeConstraint,
434        view: &mut V,
435        cx: &mut ViewContext<V>,
436    ) -> (Vector2F, Self::LayoutState) {
437        let size = self.child.layout(constraint, view, cx);
438        (size, ())
439    }
440
441    fn paint(
442        &mut self,
443        scene: &mut SceneBuilder,
444        bounds: RectF,
445        visible_bounds: RectF,
446        _: &mut Self::LayoutState,
447        view: &mut V,
448        cx: &mut ViewContext<V>,
449    ) -> Self::PaintState {
450        self.child
451            .paint(scene, bounds.origin(), visible_bounds, view, cx)
452    }
453
454    fn rect_for_text_range(
455        &self,
456        range_utf16: Range<usize>,
457        _: RectF,
458        _: RectF,
459        _: &Self::LayoutState,
460        _: &Self::PaintState,
461        view: &V,
462        cx: &ViewContext<V>,
463    ) -> Option<RectF> {
464        self.child.rect_for_text_range(range_utf16, view, cx)
465    }
466
467    fn metadata(&self) -> Option<&dyn Any> {
468        Some(&self.metadata)
469    }
470
471    fn debug(
472        &self,
473        _: RectF,
474        _: &Self::LayoutState,
475        _: &Self::PaintState,
476        view: &V,
477        cx: &ViewContext<V>,
478    ) -> Value {
479        json!({
480            "type": "Flexible",
481            "flex": self.metadata.flex,
482            "child": self.child.debug(view, cx)
483        })
484    }
485}