1use std::{any::Any, cell::Cell, f32::INFINITY, ops::Range, rc::Rc};
2
3use crate::{
4 json::{self, ToJson, Value},
5 AnyElement, Axis, Element, ElementStateHandle, SizeConstraint, Vector2FExt, ViewContext,
6};
7use pathfinder_geometry::{
8 rect::RectF,
9 vector::{vec2f, Vector2F},
10};
11use serde_json::json;
12
/// Mutable scroll bookkeeping for a scrollable [`Flex`].
///
/// It is kept behind an `Rc` so the scroll handler registered during paint can
/// keep updating it through interior mutability.
#[derive(Default)]
struct ScrollState {
    /// Index of a child that should be brought into view during the next
    /// layout; it is `take`n (reset to `None`) once handled.
    scroll_to: Cell<Option<usize>>,
    /// Current scroll offset along the flex's main axis (0 = unscrolled).
    scroll_position: Cell<f32>,
}
18
19pub struct Flex<V> {
20 axis: Axis,
21 children: Vec<AnyElement<V>>,
22 scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
23 child_alignment: f32,
24 spacing: f32,
25}
26
27impl<V: 'static> Flex<V> {
28 pub fn new(axis: Axis) -> Self {
29 Self {
30 axis,
31 children: Default::default(),
32 scroll_state: None,
33 child_alignment: -1.,
34 spacing: 0.,
35 }
36 }
37
38 pub fn row() -> Self {
39 Self::new(Axis::Horizontal)
40 }
41
42 pub fn column() -> Self {
43 Self::new(Axis::Vertical)
44 }
45
46 /// Render children centered relative to the cross-axis of the parent flex.
47 ///
48 /// If this is a flex row, children will be centered vertically. If this is a
49 /// flex column, children will be centered horizontally.
50 pub fn align_children_center(mut self) -> Self {
51 self.child_alignment = 0.;
52 self
53 }
54
55 pub fn with_spacing(mut self, spacing: f32) -> Self {
56 self.spacing = spacing;
57 self
58 }
59
60 pub fn scrollable<Tag>(
61 mut self,
62 element_id: usize,
63 scroll_to: Option<usize>,
64 cx: &mut ViewContext<V>,
65 ) -> Self
66 where
67 Tag: 'static,
68 {
69 let scroll_state = cx.default_element_state::<Tag, Rc<ScrollState>>(element_id);
70 scroll_state.read(cx).scroll_to.set(scroll_to);
71 self.scroll_state = Some((scroll_state, cx.handle().id()));
72 self
73 }
74
75 pub fn is_empty(&self) -> bool {
76 self.children.is_empty()
77 }
78
79 fn layout_flex_children(
80 &mut self,
81 layout_expanded: bool,
82 constraint: SizeConstraint,
83 remaining_space: &mut f32,
84 remaining_flex: &mut f32,
85 cross_axis_max: &mut f32,
86 view: &mut V,
87 cx: &mut ViewContext<V>,
88 ) {
89 let cross_axis = self.axis.invert();
90 for child in self.children.iter_mut() {
91 if let Some(metadata) = child.metadata::<FlexParentData>() {
92 if let Some((flex, expanded)) = metadata.flex {
93 if expanded != layout_expanded {
94 continue;
95 }
96
97 let child_max = if *remaining_flex == 0.0 {
98 *remaining_space
99 } else {
100 let space_per_flex = *remaining_space / *remaining_flex;
101 space_per_flex * flex
102 };
103 let child_min = if expanded { child_max } else { 0. };
104 let child_constraint = match self.axis {
105 Axis::Horizontal => SizeConstraint::new(
106 vec2f(child_min, constraint.min.y()),
107 vec2f(child_max, constraint.max.y()),
108 ),
109 Axis::Vertical => SizeConstraint::new(
110 vec2f(constraint.min.x(), child_min),
111 vec2f(constraint.max.x(), child_max),
112 ),
113 };
114 let child_size = child.layout(child_constraint, view, cx);
115 *remaining_space -= child_size.along(self.axis);
116 *remaining_flex -= flex;
117 *cross_axis_max = cross_axis_max.max(child_size.along(cross_axis));
118 }
119 }
120 }
121 }
122}
123
124impl<V> Extend<AnyElement<V>> for Flex<V> {
125 fn extend<T: IntoIterator<Item = AnyElement<V>>>(&mut self, children: T) {
126 self.children.extend(children);
127 }
128}
129
130impl<V: 'static> Element<V> for Flex<V> {
131 type LayoutState = f32;
132 type PaintState = ();
133
134 fn layout(
135 &mut self,
136 constraint: SizeConstraint,
137 view: &mut V,
138 cx: &mut ViewContext<V>,
139 ) -> (Vector2F, Self::LayoutState) {
140 let mut total_flex = None;
141 let mut fixed_space = self.children.len().saturating_sub(1) as f32 * self.spacing;
142 let mut contains_float = false;
143
144 let cross_axis = self.axis.invert();
145 let mut cross_axis_max: f32 = 0.0;
146 for child in self.children.iter_mut() {
147 let metadata = child.metadata::<FlexParentData>();
148 contains_float |= metadata.map_or(false, |metadata| metadata.float);
149
150 if let Some(flex) = metadata.and_then(|metadata| metadata.flex.map(|(flex, _)| flex)) {
151 *total_flex.get_or_insert(0.) += flex;
152 } else {
153 let child_constraint = match self.axis {
154 Axis::Horizontal => SizeConstraint::new(
155 vec2f(0.0, constraint.min.y()),
156 vec2f(INFINITY, constraint.max.y()),
157 ),
158 Axis::Vertical => SizeConstraint::new(
159 vec2f(constraint.min.x(), 0.0),
160 vec2f(constraint.max.x(), INFINITY),
161 ),
162 };
163 let size = child.layout(child_constraint, view, cx);
164 fixed_space += size.along(self.axis);
165 cross_axis_max = cross_axis_max.max(size.along(cross_axis));
166 }
167 }
168
169 let mut remaining_space = constraint.max_along(self.axis) - fixed_space;
170 let mut size = if let Some(mut remaining_flex) = total_flex {
171 if remaining_space.is_infinite() {
172 panic!("flex contains flexible children but has an infinite constraint along the flex axis");
173 }
174
175 self.layout_flex_children(
176 false,
177 constraint,
178 &mut remaining_space,
179 &mut remaining_flex,
180 &mut cross_axis_max,
181 view,
182 cx,
183 );
184 self.layout_flex_children(
185 true,
186 constraint,
187 &mut remaining_space,
188 &mut remaining_flex,
189 &mut cross_axis_max,
190 view,
191 cx,
192 );
193
194 match self.axis {
195 Axis::Horizontal => vec2f(constraint.max.x() - remaining_space, cross_axis_max),
196 Axis::Vertical => vec2f(cross_axis_max, constraint.max.y() - remaining_space),
197 }
198 } else {
199 match self.axis {
200 Axis::Horizontal => vec2f(fixed_space, cross_axis_max),
201 Axis::Vertical => vec2f(cross_axis_max, fixed_space),
202 }
203 };
204
205 if contains_float {
206 match self.axis {
207 Axis::Horizontal => size.set_x(size.x().max(constraint.max.x())),
208 Axis::Vertical => size.set_y(size.y().max(constraint.max.y())),
209 }
210 }
211
212 if constraint.min.x().is_finite() {
213 size.set_x(size.x().max(constraint.min.x()));
214 }
215 if constraint.min.y().is_finite() {
216 size.set_y(size.y().max(constraint.min.y()));
217 }
218
219 if size.x() > constraint.max.x() {
220 size.set_x(constraint.max.x());
221 }
222 if size.y() > constraint.max.y() {
223 size.set_y(constraint.max.y());
224 }
225
226 if let Some(scroll_state) = self.scroll_state.as_ref() {
227 scroll_state.0.update(cx, |scroll_state, _| {
228 if let Some(scroll_to) = scroll_state.scroll_to.take() {
229 let visible_start = scroll_state.scroll_position.get();
230 let visible_end = visible_start + size.along(self.axis);
231 if let Some(child) = self.children.get(scroll_to) {
232 let child_start: f32 = self.children[..scroll_to]
233 .iter()
234 .map(|c| c.size().along(self.axis))
235 .sum();
236 let child_end = child_start + child.size().along(self.axis);
237 if child_start < visible_start {
238 scroll_state.scroll_position.set(child_start);
239 } else if child_end > visible_end {
240 scroll_state
241 .scroll_position
242 .set(child_end - size.along(self.axis));
243 }
244 }
245 }
246
247 scroll_state.scroll_position.set(
248 scroll_state
249 .scroll_position
250 .get()
251 .min(-remaining_space)
252 .max(0.),
253 );
254 });
255 }
256
257 (size, remaining_space)
258 }
259
260 fn paint(
261 &mut self,
262 bounds: RectF,
263 visible_bounds: RectF,
264 remaining_space: &mut Self::LayoutState,
265 view: &mut V,
266 cx: &mut ViewContext<V>,
267 ) -> Self::PaintState {
268 let visible_bounds = bounds.intersection(visible_bounds).unwrap_or_default();
269
270 let mut remaining_space = *remaining_space;
271 let overflowing = remaining_space < 0.;
272 if overflowing {
273 cx.scene().push_layer(Some(visible_bounds));
274 }
275
276 if let Some((scroll_state, id)) = &self.scroll_state {
277 let scroll_state = scroll_state.read(cx).clone();
278 cx.scene().push_mouse_region(
279 crate::MouseRegion::new::<Self>(*id, 0, bounds)
280 .on_scroll({
281 let axis = self.axis;
282 move |e, _: &mut V, cx| {
283 if remaining_space < 0. {
284 let scroll_delta = e.delta.raw();
285
286 let mut delta = match axis {
287 Axis::Horizontal => {
288 if scroll_delta.x().abs() >= scroll_delta.y().abs() {
289 scroll_delta.x()
290 } else {
291 scroll_delta.y()
292 }
293 }
294 Axis::Vertical => scroll_delta.y(),
295 };
296 if !e.delta.precise() {
297 delta *= 20.;
298 }
299
300 scroll_state
301 .scroll_position
302 .set(scroll_state.scroll_position.get() - delta);
303
304 cx.notify();
305 } else {
306 cx.propagate_event();
307 }
308 }
309 })
310 .on_move(|_, _: &mut V, _| { /* Capture move events */ }),
311 )
312 }
313
314 let mut child_origin = bounds.origin();
315 if let Some(scroll_state) = self.scroll_state.as_ref() {
316 let scroll_position = scroll_state.0.read(cx).scroll_position.get();
317 match self.axis {
318 Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
319 Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),
320 }
321 }
322
323 for child in self.children.iter_mut() {
324 if remaining_space > 0. {
325 if let Some(metadata) = child.metadata::<FlexParentData>() {
326 if metadata.float {
327 match self.axis {
328 Axis::Horizontal => child_origin += vec2f(remaining_space, 0.0),
329 Axis::Vertical => child_origin += vec2f(0.0, remaining_space),
330 }
331 remaining_space = 0.;
332 }
333 }
334 }
335
336 // We use the child_alignment f32 to determine a point along the cross axis of the
337 // overall flex element and each child. We then align these points. So 0 would center
338 // each child relative to the overall height/width of the flex. -1 puts children at
339 // the start. 1 puts children at the end.
340 let aligned_child_origin = {
341 let cross_axis = self.axis.invert();
342 let my_center = bounds.size().along(cross_axis) / 2.;
343 let my_target = my_center + my_center * self.child_alignment;
344
345 let child_center = child.size().along(cross_axis) / 2.;
346 let child_target = child_center + child_center * self.child_alignment;
347
348 let mut aligned_child_origin = child_origin;
349 match self.axis {
350 Axis::Horizontal => aligned_child_origin
351 .set_y(aligned_child_origin.y() - (child_target - my_target)),
352 Axis::Vertical => aligned_child_origin
353 .set_x(aligned_child_origin.x() - (child_target - my_target)),
354 }
355
356 aligned_child_origin
357 };
358
359 child.paint(aligned_child_origin, visible_bounds, view, cx);
360
361 match self.axis {
362 Axis::Horizontal => child_origin += vec2f(child.size().x() + self.spacing, 0.0),
363 Axis::Vertical => child_origin += vec2f(0.0, child.size().y() + self.spacing),
364 }
365 }
366
367 if overflowing {
368 cx.scene().pop_layer();
369 }
370 }
371
372 fn rect_for_text_range(
373 &self,
374 range_utf16: Range<usize>,
375 _: RectF,
376 _: RectF,
377 _: &Self::LayoutState,
378 _: &Self::PaintState,
379 view: &V,
380 cx: &ViewContext<V>,
381 ) -> Option<RectF> {
382 self.children
383 .iter()
384 .find_map(|child| child.rect_for_text_range(range_utf16.clone(), view, cx))
385 }
386
387 fn debug(
388 &self,
389 bounds: RectF,
390 _: &Self::LayoutState,
391 _: &Self::PaintState,
392 view: &V,
393 cx: &ViewContext<V>,
394 ) -> json::Value {
395 json!({
396 "type": "Flex",
397 "bounds": bounds.to_json(),
398 "axis": self.axis.to_json(),
399 "children": self.children.iter().map(|child| child.debug(view, cx)).collect::<Vec<json::Value>>()
400 })
401 }
402}
403
/// Layout hints a [`FlexItem`] exposes to its parent [`Flex`] through element
/// metadata.
struct FlexParentData {
    /// `(weight, expanded)`.
    ///
    /// `weight` is this child's share of the leftover main-axis space.
    /// `expanded` forces the child to occupy that whole share. `None` means
    /// the child is laid out at its natural size.
    flex: Option<(f32, bool)>,
    /// When set, leftover main-axis space is inserted before this child,
    /// pushing it toward the far end of the flex.
    float: bool,
}
408
409pub struct FlexItem<V> {
410 metadata: FlexParentData,
411 child: AnyElement<V>,
412}
413
414impl<V: 'static> FlexItem<V> {
415 pub fn new(child: impl Element<V>) -> Self {
416 FlexItem {
417 metadata: FlexParentData {
418 flex: None,
419 float: false,
420 },
421 child: child.into_any(),
422 }
423 }
424
425 pub fn flex(mut self, flex: f32, expanded: bool) -> Self {
426 self.metadata.flex = Some((flex, expanded));
427 self
428 }
429
430 pub fn float(mut self) -> Self {
431 self.metadata.float = true;
432 self
433 }
434}
435
436impl<V: 'static> Element<V> for FlexItem<V> {
437 type LayoutState = ();
438 type PaintState = ();
439
440 fn layout(
441 &mut self,
442 constraint: SizeConstraint,
443 view: &mut V,
444 cx: &mut ViewContext<V>,
445 ) -> (Vector2F, Self::LayoutState) {
446 let size = self.child.layout(constraint, view, cx);
447 (size, ())
448 }
449
450 fn paint(
451 &mut self,
452 bounds: RectF,
453 visible_bounds: RectF,
454 _: &mut Self::LayoutState,
455 view: &mut V,
456 cx: &mut ViewContext<V>,
457 ) -> Self::PaintState {
458 self.child.paint(bounds.origin(), visible_bounds, view, cx)
459 }
460
461 fn rect_for_text_range(
462 &self,
463 range_utf16: Range<usize>,
464 _: RectF,
465 _: RectF,
466 _: &Self::LayoutState,
467 _: &Self::PaintState,
468 view: &V,
469 cx: &ViewContext<V>,
470 ) -> Option<RectF> {
471 self.child.rect_for_text_range(range_utf16, view, cx)
472 }
473
474 fn metadata(&self) -> Option<&dyn Any> {
475 Some(&self.metadata)
476 }
477
478 fn debug(
479 &self,
480 _: RectF,
481 _: &Self::LayoutState,
482 _: &Self::PaintState,
483 view: &V,
484 cx: &ViewContext<V>,
485 ) -> Value {
486 json!({
487 "type": "Flexible",
488 "flex": self.metadata.flex,
489 "child": self.child.debug(view, cx)
490 })
491 }
492}