flex.rs

  1use std::{any::Any, cell::Cell, f32::INFINITY, ops::Range, rc::Rc};
  2
  3use crate::{
  4    json::{self, ToJson, Value},
  5    AnyElement, Axis, Element, ElementStateHandle, SizeConstraint, TypeTag, Vector2FExt,
  6    ViewContext,
  7};
  8use pathfinder_geometry::{
  9    rect::RectF,
 10    vector::{vec2f, Vector2F},
 11};
 12use serde_json::json;
 13
 14struct ScrollState {
 15    scroll_to: Cell<Option<usize>>,
 16    scroll_position: Cell<f32>,
 17    type_tag: TypeTag,
 18}
 19
 20pub struct Flex<V> {
 21    axis: Axis,
 22    children: Vec<AnyElement<V>>,
 23    scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
 24    child_alignment: f32,
 25    spacing: f32,
 26}
 27
 28impl<V: 'static> Flex<V> {
 29    pub fn new(axis: Axis) -> Self {
 30        Self {
 31            axis,
 32            children: Default::default(),
 33            scroll_state: None,
 34            child_alignment: -1.,
 35            spacing: 0.,
 36        }
 37    }
 38
 39    pub fn row() -> Self {
 40        Self::new(Axis::Horizontal)
 41    }
 42
 43    pub fn column() -> Self {
 44        Self::new(Axis::Vertical)
 45    }
 46
 47    /// Render children centered relative to the cross-axis of the parent flex.
 48    ///
 49    /// If this is a flex row, children will be centered vertically. If this is a
 50    /// flex column, children will be centered horizontally.
 51    pub fn align_children_center(mut self) -> Self {
 52        self.child_alignment = 0.;
 53        self
 54    }
 55
 56    pub fn with_spacing(mut self, spacing: f32) -> Self {
 57        self.spacing = spacing;
 58        self
 59    }
 60
 61    pub fn scrollable<Tag>(
 62        mut self,
 63        element_id: usize,
 64        scroll_to: Option<usize>,
 65        cx: &mut ViewContext<V>,
 66    ) -> Self
 67    where
 68        Tag: 'static,
 69    {
 70        // Don't assume that this initialization is what scroll_state really is in other panes:
 71        // `element_state` is shared and there could be init races.
 72        let scroll_state = cx.element_state::<Tag, Rc<ScrollState>>(
 73            element_id,
 74            Rc::new(ScrollState {
 75                type_tag: TypeTag::new::<Tag>(),
 76                scroll_to: Default::default(),
 77                scroll_position: Default::default(),
 78            }),
 79        );
 80        // Set scroll_to separately, because the default state is already picked as `None` by other panes
 81        // by the time we start setting it here, hence update all others' state too.
 82        scroll_state.update(cx, |this, _| {
 83            this.scroll_to.set(scroll_to);
 84        });
 85        self.scroll_state = Some((scroll_state, cx.handle().id()));
 86        self
 87    }
 88
 89    pub fn is_empty(&self) -> bool {
 90        self.children.is_empty()
 91    }
 92
 93    fn layout_flex_children(
 94        &mut self,
 95        layout_expanded: bool,
 96        constraint: SizeConstraint,
 97        remaining_space: &mut f32,
 98        remaining_flex: &mut f32,
 99        cross_axis_max: &mut f32,
100        view: &mut V,
101        cx: &mut ViewContext<V>,
102    ) {
103        let cross_axis = self.axis.invert();
104        for child in self.children.iter_mut() {
105            if let Some(metadata) = child.metadata::<FlexParentData>() {
106                if let Some((flex, expanded)) = metadata.flex {
107                    if expanded != layout_expanded {
108                        continue;
109                    }
110
111                    let child_max = if *remaining_flex == 0.0 {
112                        *remaining_space
113                    } else {
114                        let space_per_flex = *remaining_space / *remaining_flex;
115                        space_per_flex * flex
116                    };
117                    let child_min = if expanded { child_max } else { 0. };
118                    let child_constraint = match self.axis {
119                        Axis::Horizontal => SizeConstraint::new(
120                            vec2f(child_min, constraint.min.y()),
121                            vec2f(child_max, constraint.max.y()),
122                        ),
123                        Axis::Vertical => SizeConstraint::new(
124                            vec2f(constraint.min.x(), child_min),
125                            vec2f(constraint.max.x(), child_max),
126                        ),
127                    };
128                    let child_size = child.layout(child_constraint, view, cx);
129                    *remaining_space -= child_size.along(self.axis);
130                    *remaining_flex -= flex;
131                    *cross_axis_max = cross_axis_max.max(child_size.along(cross_axis));
132                }
133            }
134        }
135    }
136}
137
138impl<V> Extend<AnyElement<V>> for Flex<V> {
139    fn extend<T: IntoIterator<Item = AnyElement<V>>>(&mut self, children: T) {
140        self.children.extend(children);
141    }
142}
143
144impl<V: 'static> Element<V> for Flex<V> {
145    type LayoutState = f32;
146    type PaintState = ();
147
148    fn layout(
149        &mut self,
150        constraint: SizeConstraint,
151        view: &mut V,
152        cx: &mut ViewContext<V>,
153    ) -> (Vector2F, Self::LayoutState) {
154        let mut total_flex = None;
155        let mut fixed_space = self.children.len().saturating_sub(1) as f32 * self.spacing;
156        let mut contains_float = false;
157
158        let cross_axis = self.axis.invert();
159        let mut cross_axis_max: f32 = 0.0;
160        for child in self.children.iter_mut() {
161            let metadata = child.metadata::<FlexParentData>();
162            contains_float |= metadata.map_or(false, |metadata| metadata.float);
163
164            if let Some(flex) = metadata.and_then(|metadata| metadata.flex.map(|(flex, _)| flex)) {
165                *total_flex.get_or_insert(0.) += flex;
166            } else {
167                let child_constraint = match self.axis {
168                    Axis::Horizontal => SizeConstraint::new(
169                        vec2f(0.0, constraint.min.y()),
170                        vec2f(INFINITY, constraint.max.y()),
171                    ),
172                    Axis::Vertical => SizeConstraint::new(
173                        vec2f(constraint.min.x(), 0.0),
174                        vec2f(constraint.max.x(), INFINITY),
175                    ),
176                };
177                let size = child.layout(child_constraint, view, cx);
178                fixed_space += size.along(self.axis);
179                cross_axis_max = cross_axis_max.max(size.along(cross_axis));
180            }
181        }
182
183        let mut remaining_space = constraint.max_along(self.axis) - fixed_space;
184        let mut size = if let Some(mut remaining_flex) = total_flex {
185            if remaining_space.is_infinite() {
186                panic!("flex contains flexible children but has an infinite constraint along the flex axis");
187            }
188
189            self.layout_flex_children(
190                false,
191                constraint,
192                &mut remaining_space,
193                &mut remaining_flex,
194                &mut cross_axis_max,
195                view,
196                cx,
197            );
198            self.layout_flex_children(
199                true,
200                constraint,
201                &mut remaining_space,
202                &mut remaining_flex,
203                &mut cross_axis_max,
204                view,
205                cx,
206            );
207
208            match self.axis {
209                Axis::Horizontal => vec2f(constraint.max.x() - remaining_space, cross_axis_max),
210                Axis::Vertical => vec2f(cross_axis_max, constraint.max.y() - remaining_space),
211            }
212        } else {
213            match self.axis {
214                Axis::Horizontal => vec2f(fixed_space, cross_axis_max),
215                Axis::Vertical => vec2f(cross_axis_max, fixed_space),
216            }
217        };
218
219        if contains_float {
220            match self.axis {
221                Axis::Horizontal => size.set_x(size.x().max(constraint.max.x())),
222                Axis::Vertical => size.set_y(size.y().max(constraint.max.y())),
223            }
224        }
225
226        if constraint.min.x().is_finite() {
227            size.set_x(size.x().max(constraint.min.x()));
228        }
229        if constraint.min.y().is_finite() {
230            size.set_y(size.y().max(constraint.min.y()));
231        }
232
233        if size.x() > constraint.max.x() {
234            size.set_x(constraint.max.x());
235        }
236        if size.y() > constraint.max.y() {
237            size.set_y(constraint.max.y());
238        }
239
240        if let Some(scroll_state) = self.scroll_state.as_ref() {
241            scroll_state.0.update(cx, |scroll_state, _| {
242                if let Some(scroll_to) = scroll_state.scroll_to.take() {
243                    let visible_start = scroll_state.scroll_position.get();
244                    let visible_end = visible_start + size.along(self.axis);
245                    if let Some(child) = self.children.get(scroll_to) {
246                        let child_start: f32 = self.children[..scroll_to]
247                            .iter()
248                            .map(|c| c.size().along(self.axis))
249                            .sum();
250                        let child_end = child_start + child.size().along(self.axis);
251                        if child_start < visible_start {
252                            scroll_state.scroll_position.set(child_start);
253                        } else if child_end > visible_end {
254                            scroll_state
255                                .scroll_position
256                                .set(child_end - size.along(self.axis));
257                        }
258                    }
259                }
260
261                scroll_state.scroll_position.set(
262                    scroll_state
263                        .scroll_position
264                        .get()
265                        .min(-remaining_space)
266                        .max(0.),
267                );
268            });
269        }
270
271        (size, remaining_space)
272    }
273
274    fn paint(
275        &mut self,
276        bounds: RectF,
277        visible_bounds: RectF,
278        remaining_space: &mut Self::LayoutState,
279        view: &mut V,
280        cx: &mut ViewContext<V>,
281    ) -> Self::PaintState {
282        let visible_bounds = bounds.intersection(visible_bounds).unwrap_or_default();
283
284        let mut remaining_space = *remaining_space;
285        let overflowing = remaining_space < 0.;
286        if overflowing {
287            cx.scene().push_layer(Some(visible_bounds));
288        }
289
290        if let Some((scroll_state, id)) = &self.scroll_state {
291            let scroll_state = scroll_state.read(cx).clone();
292            cx.scene().push_mouse_region(
293                crate::MouseRegion::from_handlers(
294                    scroll_state.type_tag,
295                    *id,
296                    0,
297                    bounds,
298                    Default::default(),
299                )
300                .on_scroll({
301                    let axis = self.axis;
302                    move |e, _: &mut V, cx| {
303                        if remaining_space < 0. {
304                            let scroll_delta = e.delta.raw();
305
306                            let mut delta = match axis {
307                                Axis::Horizontal => {
308                                    if scroll_delta.x().abs() >= scroll_delta.y().abs() {
309                                        scroll_delta.x()
310                                    } else {
311                                        scroll_delta.y()
312                                    }
313                                }
314                                Axis::Vertical => scroll_delta.y(),
315                            };
316                            if !e.delta.precise() {
317                                delta *= 20.;
318                            }
319
320                            scroll_state
321                                .scroll_position
322                                .set(scroll_state.scroll_position.get() - delta);
323
324                            cx.notify();
325                        } else {
326                            cx.propagate_event();
327                        }
328                    }
329                })
330                .on_move(|_, _: &mut V, _| { /* Capture move events */ }),
331            )
332        }
333
334        let mut child_origin = bounds.origin();
335        if let Some(scroll_state) = self.scroll_state.as_ref() {
336            let scroll_position = scroll_state.0.read(cx).scroll_position.get();
337            match self.axis {
338                Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
339                Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),
340            }
341        }
342
343        for child in self.children.iter_mut() {
344            if remaining_space > 0. {
345                if let Some(metadata) = child.metadata::<FlexParentData>() {
346                    if metadata.float {
347                        match self.axis {
348                            Axis::Horizontal => child_origin += vec2f(remaining_space, 0.0),
349                            Axis::Vertical => child_origin += vec2f(0.0, remaining_space),
350                        }
351                        remaining_space = 0.;
352                    }
353                }
354            }
355
356            // We use the child_alignment f32 to determine a point along the cross axis of the
357            // overall flex element and each child. We then align these points. So 0 would center
358            // each child relative to the overall height/width of the flex. -1 puts children at
359            // the start. 1 puts children at the end.
360            let aligned_child_origin = {
361                let cross_axis = self.axis.invert();
362                let my_center = bounds.size().along(cross_axis) / 2.;
363                let my_target = my_center + my_center * self.child_alignment;
364
365                let child_center = child.size().along(cross_axis) / 2.;
366                let child_target = child_center + child_center * self.child_alignment;
367
368                let mut aligned_child_origin = child_origin;
369                match self.axis {
370                    Axis::Horizontal => aligned_child_origin
371                        .set_y(aligned_child_origin.y() - (child_target - my_target)),
372                    Axis::Vertical => aligned_child_origin
373                        .set_x(aligned_child_origin.x() - (child_target - my_target)),
374                }
375
376                aligned_child_origin
377            };
378
379            child.paint(aligned_child_origin, visible_bounds, view, cx);
380
381            match self.axis {
382                Axis::Horizontal => child_origin += vec2f(child.size().x() + self.spacing, 0.0),
383                Axis::Vertical => child_origin += vec2f(0.0, child.size().y() + self.spacing),
384            }
385        }
386
387        if overflowing {
388            cx.scene().pop_layer();
389        }
390    }
391
392    fn rect_for_text_range(
393        &self,
394        range_utf16: Range<usize>,
395        _: RectF,
396        _: RectF,
397        _: &Self::LayoutState,
398        _: &Self::PaintState,
399        view: &V,
400        cx: &ViewContext<V>,
401    ) -> Option<RectF> {
402        self.children
403            .iter()
404            .find_map(|child| child.rect_for_text_range(range_utf16.clone(), view, cx))
405    }
406
407    fn debug(
408        &self,
409        bounds: RectF,
410        _: &Self::LayoutState,
411        _: &Self::PaintState,
412        view: &V,
413        cx: &ViewContext<V>,
414    ) -> json::Value {
415        json!({
416            "type": "Flex",
417            "bounds": bounds.to_json(),
418            "axis": self.axis.to_json(),
419            "children": self.children.iter().map(|child| child.debug(view, cx)).collect::<Vec<json::Value>>()
420        })
421    }
422}
423
/// Layout hints that a [`FlexItem`] exposes to its parent [`Flex`] via `metadata`.
struct FlexParentData {
    /// `Some((flex, expanded))` marks a flexible item. `flex` weights the item's share of
    /// leftover main-axis space, and `expanded` forces the item to fill that share.
    flex: Option<(f32, bool)>,
    /// If set, positive leftover main-axis space is inserted before this item at paint.
    float: bool,
}
428
429pub struct FlexItem<V> {
430    metadata: FlexParentData,
431    child: AnyElement<V>,
432}
433
434impl<V: 'static> FlexItem<V> {
435    pub fn new(child: impl Element<V>) -> Self {
436        FlexItem {
437            metadata: FlexParentData {
438                flex: None,
439                float: false,
440            },
441            child: child.into_any(),
442        }
443    }
444
445    pub fn flex(mut self, flex: f32, expanded: bool) -> Self {
446        self.metadata.flex = Some((flex, expanded));
447        self
448    }
449
450    pub fn float(mut self) -> Self {
451        self.metadata.float = true;
452        self
453    }
454}
455
456impl<V: 'static> Element<V> for FlexItem<V> {
457    type LayoutState = ();
458    type PaintState = ();
459
460    fn layout(
461        &mut self,
462        constraint: SizeConstraint,
463        view: &mut V,
464        cx: &mut ViewContext<V>,
465    ) -> (Vector2F, Self::LayoutState) {
466        let size = self.child.layout(constraint, view, cx);
467        (size, ())
468    }
469
470    fn paint(
471        &mut self,
472        bounds: RectF,
473        visible_bounds: RectF,
474        _: &mut Self::LayoutState,
475        view: &mut V,
476        cx: &mut ViewContext<V>,
477    ) -> Self::PaintState {
478        self.child.paint(bounds.origin(), visible_bounds, view, cx)
479    }
480
481    fn rect_for_text_range(
482        &self,
483        range_utf16: Range<usize>,
484        _: RectF,
485        _: RectF,
486        _: &Self::LayoutState,
487        _: &Self::PaintState,
488        view: &V,
489        cx: &ViewContext<V>,
490    ) -> Option<RectF> {
491        self.child.rect_for_text_range(range_utf16, view, cx)
492    }
493
494    fn metadata(&self) -> Option<&dyn Any> {
495        Some(&self.metadata)
496    }
497
498    fn debug(
499        &self,
500        _: RectF,
501        _: &Self::LayoutState,
502        _: &Self::PaintState,
503        view: &V,
504        cx: &ViewContext<V>,
505    ) -> Value {
506        json!({
507            "type": "Flexible",
508            "flex": self.metadata.flex,
509            "child": self.child.debug(view, cx)
510        })
511    }
512}