Finished fixing flex scrolls

Mikayla Maki created

Change summary

crates/gpui/src/elements/flex.rs | 115 +++++++++++++--------------------
1 file changed, 47 insertions(+), 68 deletions(-)

Detailed changes

crates/gpui/src/elements/flex.rs 🔗

@@ -4,8 +4,7 @@ use crate::{
     json::{self, ToJson, Value},
     presenter::MeasurementContext,
     Axis, DebugContext, Element, ElementBox, ElementStateHandle, Event, EventContext,
-    LayoutContext, MouseRegion, PaintContext, RenderContext, ScrollWheelEvent, SizeConstraint,
-    Vector2FExt, View,
+    LayoutContext, PaintContext, RenderContext, SizeConstraint, Vector2FExt, View,
 };
 use pathfinder_geometry::{
     rect::RectF,
@@ -13,16 +12,16 @@ use pathfinder_geometry::{
 };
 use serde_json::json;
 
-#[derive(Default, Clone, Copy)]
+#[derive(Default)]
 struct ScrollState {
-    scroll_to: Option<usize>,
-    scroll_position: f32,
+    scroll_to: Cell<Option<usize>>,
+    scroll_position: Cell<f32>,
 }
 
 pub struct Flex {
     axis: Axis,
     children: Vec<ElementBox>,
-    scroll_state: Option<(ElementStateHandle<Rc<Cell<ScrollState>>>, usize)>,
+    scroll_state: Option<(ElementStateHandle<Rc<ScrollState>>, usize)>,
 }
 
 impl Flex {
@@ -52,15 +51,9 @@ impl Flex {
         Tag: 'static,
         V: View,
     {
-        let scroll_state_handle =
-            cx.default_element_state::<Tag, Rc<Cell<ScrollState>>>(element_id);
-        let scroll_state_cell = scroll_state_handle.read(cx);
-        let mut scroll_state = scroll_state_cell.get();
-        scroll_state.scroll_to = scroll_to;
-        scroll_state_cell.set(scroll_state);
-
-        self.scroll_state = Some((scroll_state_handle, cx.handle().id()));
-
+        let scroll_state = cx.default_element_state::<Tag, Rc<ScrollState>>(element_id);
+        scroll_state.read(cx).scroll_to.set(scroll_to);
+        self.scroll_state = Some((scroll_state, cx.handle().id()));
         self
     }
 
@@ -106,38 +99,6 @@ impl Flex {
             }
         }
     }
-
-    fn handle_scroll(
-        e: ScrollWheelEvent,
-        axis: Axis,
-        scroll_state: Rc<Cell<ScrollState>>,
-        remaining_space: f32,
-    ) -> bool {
-        let precise = e.precise;
-        let delta = e.delta;
-        if remaining_space < 0. {
-            let mut delta = match axis {
-                Axis::Horizontal => {
-                    if delta.x() != 0. {
-                        delta.x()
-                    } else {
-                        delta.y()
-                    }
-                }
-                Axis::Vertical => delta.y(),
-            };
-            if !precise {
-                delta *= 20.;
-            }
-
-            let mut old_state = scroll_state.get();
-            old_state.scroll_position -= delta;
-            scroll_state.set(old_state);
-
-            return true;
-        }
-        return false;
-    }
 }
 
 impl Extend<ElementBox> for Flex {
@@ -241,8 +202,8 @@ impl Element for Flex {
 
         if let Some(scroll_state) = self.scroll_state.as_ref() {
             scroll_state.0.update(cx, |scroll_state, _| {
-                if let Some(scroll_to) = scroll_state.get().scroll_to.take() {
-                    let visible_start = scroll_state.get().scroll_position;
+                if let Some(scroll_to) = scroll_state.scroll_to.take() {
+                    let visible_start = scroll_state.scroll_position.get();
                     let visible_end = visible_start + size.along(self.axis);
                     if let Some(child) = self.children.get(scroll_to) {
                         let child_start: f32 = self.children[..scroll_to]
@@ -250,20 +211,23 @@ impl Element for Flex {
                             .map(|c| c.size().along(self.axis))
                             .sum();
                         let child_end = child_start + child.size().along(self.axis);
-
-                        let mut old_state = scroll_state.get();
                         if child_start < visible_start {
-                            old_state.scroll_position = child_start;
+                            scroll_state.scroll_position.set(child_start);
                         } else if child_end > visible_end {
-                            old_state.scroll_position = child_end - size.along(self.axis);
+                            scroll_state
+                                .scroll_position
+                                .set(child_end - size.along(self.axis));
                         }
-                        scroll_state.set(old_state);
                     }
                 }
 
-                let mut old_state = scroll_state.get();
-                old_state.scroll_position = old_state.scroll_position.min(-remaining_space).max(0.);
-                scroll_state.set(old_state);
+                scroll_state.scroll_position.set(
+                    scroll_state
+                        .scroll_position
+                        .get()
+                        .min(-remaining_space)
+                        .max(0.),
+                );
             });
         }
 
@@ -286,28 +250,43 @@ impl Element for Flex {
 
         if let Some(scroll_state) = &self.scroll_state {
             cx.scene.push_mouse_region(
-                MouseRegion::new::<Self>(scroll_state.1, 0, bounds)
+                crate::MouseRegion::new::<Self>(scroll_state.1, 0, bounds)
                     .on_scroll({
-                        let axis = self.axis;
                         let scroll_state = scroll_state.0.read(cx).clone();
+                        let axis = self.axis;
                         move |e, cx| {
-                            if Self::handle_scroll(
-                                e.platform_event,
-                                axis,
-                                scroll_state.clone(),
-                                remaining_space,
-                            ) {
+                            if remaining_space < 0. {
+                                let mut delta = match axis {
+                                    Axis::Horizontal => {
+                                        if e.delta.x() != 0. {
+                                            e.delta.x()
+                                        } else {
+                                            e.delta.y()
+                                        }
+                                    }
+                                    Axis::Vertical => e.delta.y(),
+                                };
+                                if !e.precise {
+                                    delta *= 20.;
+                                }
+
+                                scroll_state
+                                    .scroll_position
+                                    .set(scroll_state.scroll_position.get() - delta);
+
+                                cx.notify();
+                            } else {
                                 cx.propogate_event();
                             }
                         }
                     })
-                    .on_move(|_, _| { /* Eat move events so they don't propogate */ }),
-            );
+                    .on_move(|_, _| { /* Capture move events */ }),
+            )
         }
 
         let mut child_origin = bounds.origin();
         if let Some(scroll_state) = self.scroll_state.as_ref() {
-            let scroll_position = scroll_state.0.read(cx).get().scroll_position;
+            let scroll_position = scroll_state.0.read(cx).scroll_position.get();
             match self.axis {
                 Axis::Horizontal => child_origin.set_x(child_origin.x() - scroll_position),
                 Axis::Vertical => child_origin.set_y(child_origin.y() - scroll_position),