flex.rs

  1use std::{any::Any, cell::Cell, f32::INFINITY, ops::Range, rc::Rc};
  2
  3use crate::{
  4    json::{self, ToJson, Value},
  5    AnyElement, Axis, Element, ElementStateHandle, LayoutContext, PaintContext, SceneBuilder,
  6    SizeConstraint, Vector2FExt, ViewContext,
  7};
  8use pathfinder_geometry::{
  9    rect::RectF,
 10    vector::{vec2f, Vector2F},
 11};
 12use serde_json::json;
 13
 14#[derive(Default)]
 15struct ScrollState {
 16    scroll_to: Cell<Option<usize>>,
 17    scroll_position: Cell<f32>,
 18}
 19
 20pub struct Flex<V> {
 21    axis: Axis,
 22    children: Vec<AnyElement<V>>,
 23    scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
 24    child_alignment: f32,
 25}
 26
 27impl<V: 'static> Flex<V> {
 28    pub fn new(axis: Axis) -> Self {
 29        Self {
 30            axis,
 31            children: Default::default(),
 32            scroll_state: None,
 33            child_alignment: -1.,
 34        }
 35    }
 36
 37    pub fn row() -> Self {
 38        Self::new(Axis::Horizontal)
 39    }
 40
 41    pub fn column() -> Self {
 42        Self::new(Axis::Vertical)
 43    }
 44
 45    /// Render children centered relative to the cross-axis of the parent flex.
 46    ///
 47    /// If this is a flex row, children will be centered vertically. If this is a
 48    /// flex column, children will be centered horizontally.
 49    pub fn align_children_center(mut self) -> Self {
 50        self.child_alignment = 0.;
 51        self
 52    }
 53
 54    pub fn scrollable<Tag>(
 55        mut self,
 56        element_id: usize,
 57        scroll_to: Option<usize>,
 58        cx: &mut ViewContext<V>,
 59    ) -> Self
 60    where
 61        Tag: 'static,
 62    {
 63        let scroll_state = cx.default_element_state::<Tag, Rc<ScrollState>>(element_id);
 64        scroll_state.read(cx).scroll_to.set(scroll_to);
 65        self.scroll_state = Some((scroll_state, cx.handle().id()));
 66        self
 67    }
 68
 69    pub fn is_empty(&self) -> bool {
 70        self.children.is_empty()
 71    }
 72
 73    fn layout_flex_children(
 74        &mut self,
 75        layout_expanded: bool,
 76        constraint: SizeConstraint,
 77        remaining_space: &mut f32,
 78        remaining_flex: &mut f32,
 79        cross_axis_max: &mut f32,
 80        view: &mut V,
 81        cx: &mut LayoutContext<V>,
 82    ) {
 83        let cross_axis = self.axis.invert();
 84        for child in &mut self.children {
 85            if let Some(metadata) = child.metadata::<FlexParentData>() {
 86                if let Some((flex, expanded)) = metadata.flex {
 87                    if expanded != layout_expanded {
 88                        continue;
 89                    }
 90
 91                    let child_max = if *remaining_flex == 0.0 {
 92                        *remaining_space
 93                    } else {
 94                        let space_per_flex = *remaining_space / *remaining_flex;
 95                        space_per_flex * flex
 96                    };
 97                    let child_min = if expanded { child_max } else { 0. };
 98                    let child_constraint = match self.axis {
 99                        Axis::Horizontal => SizeConstraint::new(
100                            vec2f(child_min, constraint.min.y()),
101                            vec2f(child_max, constraint.max.y()),
102                        ),
103                        Axis::Vertical => SizeConstraint::new(
104                            vec2f(constraint.min.x(), child_min),
105                            vec2f(constraint.max.x(), child_max),
106                        ),
107                    };
108                    let child_size = child.layout(child_constraint, view, cx);
109                    *remaining_space -= child_size.along(self.axis);
110                    *remaining_flex -= flex;
111                    *cross_axis_max = cross_axis_max.max(child_size.along(cross_axis));
112                }
113            }
114        }
115    }
116}
117
118impl<V> Extend<AnyElement<V>> for Flex<V> {
119    fn extend<T: IntoIterator<Item = AnyElement<V>>>(&mut self, children: T) {
120        self.children.extend(children);
121    }
122}
123
124impl<V: 'static> Element<V> for Flex<V> {
125    type LayoutState = f32;
126    type PaintState = ();
127
128    fn layout(
129        &mut self,
130        constraint: SizeConstraint,
131        view: &mut V,
132        cx: &mut LayoutContext<V>,
133    ) -> (Vector2F, Self::LayoutState) {
134        let mut total_flex = None;
135        let mut fixed_space = 0.0;
136        let mut contains_float = false;
137
138        let cross_axis = self.axis.invert();
139        let mut cross_axis_max: f32 = 0.0;
140        for child in &mut self.children {
141            let metadata = child.metadata::<FlexParentData>();
142            contains_float |= metadata.map_or(false, |metadata| metadata.float);
143
144            if let Some(flex) = metadata.and_then(|metadata| metadata.flex.map(|(flex, _)| flex)) {
145                *total_flex.get_or_insert(0.) += flex;
146            } else {
147                let child_constraint = match self.axis {
148                    Axis::Horizontal => SizeConstraint::new(
149                        vec2f(0.0, constraint.min.y()),
150                        vec2f(INFINITY, constraint.max.y()),
151                    ),
152                    Axis::Vertical => SizeConstraint::new(
153                        vec2f(constraint.min.x(), 0.0),
154                        vec2f(constraint.max.x(), INFINITY),
155                    ),
156                };
157                let size = child.layout(child_constraint, view, cx);
158                fixed_space += size.along(self.axis);
159                cross_axis_max = cross_axis_max.max(size.along(cross_axis));
160            }
161        }
162
163        let mut remaining_space = constraint.max_along(self.axis) - fixed_space;
164        let mut size = if let Some(mut remaining_flex) = total_flex {
165            if remaining_space.is_infinite() {
166                panic!("flex contains flexible children but has an infinite constraint along the flex axis");
167            }
168
169            self.layout_flex_children(
170                false,
171                constraint,
172                &mut remaining_space,
173                &mut remaining_flex,
174                &mut cross_axis_max,
175                view,
176                cx,
177            );
178            self.layout_flex_children(
179                true,
180                constraint,
181                &mut remaining_space,
182                &mut remaining_flex,
183                &mut cross_axis_max,
184                view,
185                cx,
186            );
187
188            match self.axis {
189                Axis::Horizontal => vec2f(constraint.max.x() - remaining_space, cross_axis_max),
190                Axis::Vertical => vec2f(cross_axis_max, constraint.max.y() - remaining_space),
191            }
192        } else {
193            match self.axis {
194                Axis::Horizontal => vec2f(fixed_space, cross_axis_max),
195                Axis::Vertical => vec2f(cross_axis_max, fixed_space),
196            }
197        };
198
199        if contains_float {
200            match self.axis {
201                Axis::Horizontal => size.set_x(size.x().max(constraint.max.x())),
202                Axis::Vertical => size.set_y(size.y().max(constraint.max.y())),
203            }
204        }
205
206        if constraint.min.x().is_finite() {
207            size.set_x(size.x().max(constraint.min.x()));
208        }
209        if constraint.min.y().is_finite() {
210            size.set_y(size.y().max(constraint.min.y()));
211        }
212
213        if size.x() > constraint.max.x() {
214            size.set_x(constraint.max.x());
215        }
216        if size.y() > constraint.max.y() {
217            size.set_y(constraint.max.y());
218        }
219
220        if let Some(scroll_state) = self.scroll_state.as_ref() {
221            scroll_state.0.update(cx.view_context(), |scroll_state, _| {
222                if let Some(scroll_to) = scroll_state.scroll_to.take() {
223                    let visible_start = scroll_state.scroll_position.get();
224                    let visible_end = visible_start + size.along(self.axis);
225                    if let Some(child) = self.children.get(scroll_to) {
226                        let child_start: f32 = self.children[..scroll_to]
227                            .iter()
228                            .map(|c| c.size().along(self.axis))
229                            .sum();
230                        let child_end = child_start + child.size().along(self.axis);
231                        if child_start < visible_start {
232                            scroll_state.scroll_position.set(child_start);
233                        } else if child_end > visible_end {
234                            scroll_state
235                                .scroll_position
236                                .set(child_end - size.along(self.axis));
237                        }
238                    }
239                }
240
241                scroll_state.scroll_position.set(
242                    scroll_state
243                        .scroll_position
244                        .get()
245                        .min(-remaining_space)
246                        .max(0.),
247                );
248            });
249        }
250
251        (size, remaining_space)
252    }
253
254    fn paint(
255        &mut self,
256        scene: &mut SceneBuilder,
257        bounds: RectF,
258        visible_bounds: RectF,
259        remaining_space: &mut Self::LayoutState,
260        view: &mut V,
261        cx: &mut PaintContext<V>,
262    ) -> Self::PaintState {
263        let visible_bounds = bounds.intersection(visible_bounds).unwrap_or_default();
264
265        let mut remaining_space = *remaining_space;
266        let overflowing = remaining_space < 0.;
267        if overflowing {
268            scene.push_layer(Some(visible_bounds));
269        }
270
271        if let Some(scroll_state) = &self.scroll_state {
272            scene.push_mouse_region(
273                crate::MouseRegion::new::<Self>(scroll_state.1, 0, bounds)
274                    .on_scroll({
275                        let scroll_state = scroll_state.0.read(cx).clone();
276                        let axis = self.axis;
277                        move |e, _: &mut V, cx| {
278                            if remaining_space < 0. {
279                                let scroll_delta = e.delta.raw();
280
281                                let mut delta = match axis {
282                                    Axis::Horizontal => {
283                                        if scroll_delta.x().abs() >= scroll_delta.y().abs() {
284                                            scroll_delta.x()
285                                        } else {
286                                            scroll_delta.y()
287                                        }
288                                    }
289                                    Axis::Vertical => scroll_delta.y(),
290                                };
291                                if !e.delta.precise() {
292                                    delta *= 20.;
293                                }
294
295                                scroll_state
296                                    .scroll_position
297                                    .set(scroll_state.scroll_position.get() - delta);
298
299                                cx.notify();
300                            } else {
301                                cx.propagate_event();
302                            }
303                        }
304                    })
305                    .on_move(|_, _: &mut V, _| { /* Capture move events */ }),
306            )
307        }
308
309        let mut child_origin = bounds.origin();
310        if let Some(scroll_state) = self.scroll_state.as_ref() {
311            let scroll_position = scroll_state.0.read(cx).scroll_position.get();
312            match self.axis {
313                Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
314                Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),
315            }
316        }
317
318        for child in &mut self.children {
319            if remaining_space > 0. {
320                if let Some(metadata) = child.metadata::<FlexParentData>() {
321                    if metadata.float {
322                        match self.axis {
323                            Axis::Horizontal => child_origin += vec2f(remaining_space, 0.0),
324                            Axis::Vertical => child_origin += vec2f(0.0, remaining_space),
325                        }
326                        remaining_space = 0.;
327                    }
328                }
329            }
330
331            // We use the child_alignment f32 to determine a point along the cross axis of the
332            // overall flex element and each child. We then align these points. So 0 would center
333            // each child relative to the overall height/width of the flex. -1 puts children at
334            // the start. 1 puts children at the end.
335            let aligned_child_origin = {
336                let cross_axis = self.axis.invert();
337                let my_center = bounds.size().along(cross_axis) / 2.;
338                let my_target = my_center + my_center * self.child_alignment;
339
340                let child_center = child.size().along(cross_axis) / 2.;
341                let child_target = child_center + child_center * self.child_alignment;
342
343                let mut aligned_child_origin = child_origin;
344                match self.axis {
345                    Axis::Horizontal => aligned_child_origin
346                        .set_y(aligned_child_origin.y() - (child_target - my_target)),
347                    Axis::Vertical => aligned_child_origin
348                        .set_x(aligned_child_origin.x() - (child_target - my_target)),
349                }
350
351                aligned_child_origin
352            };
353
354            child.paint(scene, aligned_child_origin, visible_bounds, view, cx);
355
356            match self.axis {
357                Axis::Horizontal => child_origin += vec2f(child.size().x(), 0.0),
358                Axis::Vertical => child_origin += vec2f(0.0, child.size().y()),
359            }
360        }
361
362        if overflowing {
363            scene.pop_layer();
364        }
365    }
366
367    fn rect_for_text_range(
368        &self,
369        range_utf16: Range<usize>,
370        _: RectF,
371        _: RectF,
372        _: &Self::LayoutState,
373        _: &Self::PaintState,
374        view: &V,
375        cx: &ViewContext<V>,
376    ) -> Option<RectF> {
377        self.children
378            .iter()
379            .find_map(|child| child.rect_for_text_range(range_utf16.clone(), view, cx))
380    }
381
382    fn debug(
383        &self,
384        bounds: RectF,
385        _: &Self::LayoutState,
386        _: &Self::PaintState,
387        view: &V,
388        cx: &ViewContext<V>,
389    ) -> json::Value {
390        json!({
391            "type": "Flex",
392            "bounds": bounds.to_json(),
393            "axis": self.axis.to_json(),
394            "children": self.children.iter().map(|child| child.debug(view, cx)).collect::<Vec<json::Value>>()
395        })
396    }
397}
398
/// Layout hints that a [`FlexItem`] exposes to its parent [`Flex`] through
/// `Element::metadata`.
struct FlexParentData {
    /// `Some((flex, expanded))` marks a flexible child. `flex` is its
    /// proportional share of the leftover main-axis space, and `expanded`
    /// forces it to fill that share exactly.
    flex: Option<(f32, bool)>,
    /// When `true`, any leftover main-axis space is inserted before this child
    /// during paint, pushing it toward the end.
    float: bool,
}
403
404pub struct FlexItem<V> {
405    metadata: FlexParentData,
406    child: AnyElement<V>,
407}
408
409impl<V: 'static> FlexItem<V> {
410    pub fn new(child: impl Element<V>) -> Self {
411        FlexItem {
412            metadata: FlexParentData {
413                flex: None,
414                float: false,
415            },
416            child: child.into_any(),
417        }
418    }
419
420    pub fn flex(mut self, flex: f32, expanded: bool) -> Self {
421        self.metadata.flex = Some((flex, expanded));
422        self
423    }
424
425    pub fn float(mut self) -> Self {
426        self.metadata.float = true;
427        self
428    }
429}
430
431impl<V: 'static> Element<V> for FlexItem<V> {
432    type LayoutState = ();
433    type PaintState = ();
434
435    fn layout(
436        &mut self,
437        constraint: SizeConstraint,
438        view: &mut V,
439        cx: &mut LayoutContext<V>,
440    ) -> (Vector2F, Self::LayoutState) {
441        let size = self.child.layout(constraint, view, cx);
442        (size, ())
443    }
444
445    fn paint(
446        &mut self,
447        scene: &mut SceneBuilder,
448        bounds: RectF,
449        visible_bounds: RectF,
450        _: &mut Self::LayoutState,
451        view: &mut V,
452        cx: &mut PaintContext<V>,
453    ) -> Self::PaintState {
454        self.child
455            .paint(scene, bounds.origin(), visible_bounds, view, cx)
456    }
457
458    fn rect_for_text_range(
459        &self,
460        range_utf16: Range<usize>,
461        _: RectF,
462        _: RectF,
463        _: &Self::LayoutState,
464        _: &Self::PaintState,
465        view: &V,
466        cx: &ViewContext<V>,
467    ) -> Option<RectF> {
468        self.child.rect_for_text_range(range_utf16, view, cx)
469    }
470
471    fn metadata(&self) -> Option<&dyn Any> {
472        Some(&self.metadata)
473    }
474
475    fn debug(
476        &self,
477        _: RectF,
478        _: &Self::LayoutState,
479        _: &Self::PaintState,
480        view: &V,
481        cx: &ViewContext<V>,
482    ) -> Value {
483        json!({
484            "type": "Flexible",
485            "flex": self.metadata.flex,
486            "child": self.child.debug(view, cx)
487        })
488    }
489}