Add global tests and wrap global update functions in update call to flush effects

Keith Simmons and Antonio Scandurra created

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/gpui/src/app.rs | 173 +++++++++++++++++++++++++++++++------------
1 file changed, 124 insertions(+), 49 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -763,7 +763,7 @@ type GlobalActionCallback = dyn FnMut(&dyn AnyAction, &mut MutableAppContext);
 type SubscriptionCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext) -> bool>;
 type GlobalSubscriptionCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>;
 type ObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
-type GlobalObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
+type GlobalObservationCallback = Box<dyn FnMut(&mut MutableAppContext)>;
 type ReleaseObservationCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>;
 
 pub struct MutableAppContext {
@@ -1206,7 +1206,7 @@ impl MutableAppContext {
     pub fn observe_global<G, F>(&mut self, observe: F) -> Subscription
     where
         G: Any,
-        F: 'static + FnMut(&mut MutableAppContext) -> bool,
+        F: 'static + FnMut(&mut MutableAppContext),
     {
         let type_id = TypeId::of::<G>();
         let id = post_inc(&mut self.next_subscription_id);
@@ -1415,21 +1415,25 @@ impl MutableAppContext {
 
     pub fn default_global<T: 'static + Default>(&mut self) -> &T {
         let type_id = TypeId::of::<T>();
-        if !self.cx.globals.contains_key(&type_id) {
-            self.notify_global(type_id);
-        }
-        self.cx
-            .globals
-            .entry(type_id)
-            .or_insert_with(|| Box::new(T::default()))
-            .downcast_ref()
-            .unwrap()
+        self.update(|this| {
+            if !this.globals.contains_key(&type_id) {
+                this.notify_global(type_id);
+            }
+
+            this.cx
+                .globals
+                .entry(type_id)
+                .or_insert_with(|| Box::new(T::default()));
+        });
+        self.globals.get(&type_id).unwrap().downcast_ref().unwrap()
     }
 
     pub fn set_global<T: 'static>(&mut self, state: T) {
-        let type_id = TypeId::of::<T>();
-        self.cx.globals.insert(type_id, Box::new(state));
-        self.notify_global(type_id);
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            this.cx.globals.insert(type_id, Box::new(state));
+            this.notify_global(type_id);
+        });
     }
 
     pub fn update_default_global<T, F, U>(&mut self, update: F) -> U
@@ -1437,16 +1441,18 @@ impl MutableAppContext {
         T: 'static + Default,
         F: FnOnce(&mut T, &mut MutableAppContext) -> U,
     {
-        let type_id = TypeId::of::<T>();
-        let mut state = self
-            .cx
-            .globals
-            .remove(&type_id)
-            .unwrap_or_else(|| Box::new(T::default()));
-        let result = update(state.downcast_mut().unwrap(), self);
-        self.cx.globals.insert(type_id, state);
-        self.notify_global(type_id);
-        result
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            let mut state = this
+                .cx
+                .globals
+                .remove(&type_id)
+                .unwrap_or_else(|| Box::new(T::default()));
+            let result = update(state.downcast_mut().unwrap(), this);
+            this.cx.globals.insert(type_id, state);
+            this.notify_global(type_id);
+            result
+        })
     }
 
     pub fn update_global<T, F, U>(&mut self, update: F) -> U
@@ -1454,16 +1460,18 @@ impl MutableAppContext {
         T: 'static,
         F: FnOnce(&mut T, &mut MutableAppContext) -> U,
     {
-        let type_id = TypeId::of::<T>();
-        let mut state = self
-            .cx
-            .globals
-            .remove(&type_id)
-            .expect("no global has been added for this type");
-        let result = update(state.downcast_mut().unwrap(), self);
-        self.cx.globals.insert(type_id, state);
-        self.notify_global(type_id);
-        result
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            let mut state = this
+                .cx
+                .globals
+                .remove(&type_id)
+                .expect("no global has been added for this type");
+            let result = update(state.downcast_mut().unwrap(), this);
+            this.cx.globals.insert(type_id, state);
+            this.notify_global(type_id);
+            result
+        })
     }
 
     pub fn add_model<T, F>(&mut self, build_model: F) -> ModelHandle<T>
@@ -2056,21 +2064,19 @@ impl MutableAppContext {
             if self.cx.globals.contains_key(&observed_type_id) {
                 for (id, callback) in callbacks {
                     if let Some(mut callback) = callback {
-                        let alive = callback(self);
-                        if alive {
-                            match self
-                                .global_observations
-                                .lock()
-                                .entry(observed_type_id)
-                                .or_default()
-                                .entry(id)
-                            {
-                                collections::btree_map::Entry::Vacant(entry) => {
-                                    entry.insert(Some(callback));
-                                }
-                                collections::btree_map::Entry::Occupied(entry) => {
-                                    entry.remove();
-                                }
+                        callback(self);
+                        match self
+                            .global_observations
+                            .lock()
+                            .entry(observed_type_id)
+                            .or_default()
+                            .entry(id)
+                        {
+                            collections::btree_map::Entry::Vacant(entry) => {
+                                entry.insert(Some(callback));
+                            }
+                            collections::btree_map::Entry::Occupied(entry) => {
+                                entry.remove();
                             }
                         }
                     }
@@ -5205,6 +5211,61 @@ mod tests {
         );
     }
 
+    #[crate::test(self)]
+    fn test_global(cx: &mut MutableAppContext) {
+        type Global = usize;
+
+        let observation_count = Rc::new(RefCell::new(0));
+        let subscription = cx.observe_global::<Global, _>({
+            let observation_count = observation_count.clone();
+            move |_| {
+                *observation_count.borrow_mut() += 1;
+            }
+        });
+
+        assert!(!cx.has_global::<Global>());
+        assert_eq!(cx.default_global::<Global>(), &0);
+        assert_eq!(*observation_count.borrow(), 1);
+        assert!(cx.has_global::<Global>());
+        assert_eq!(
+            cx.update_global::<Global, _, _>(|global, _| {
+                *global = 1;
+                "Update Result"
+            }),
+            "Update Result"
+        );
+        assert_eq!(*observation_count.borrow(), 2);
+        assert_eq!(cx.global::<Global>(), &1);
+
+        drop(subscription);
+        cx.update_global::<Global, _, _>(|global, _| {
+            *global = 2;
+        });
+        assert_eq!(*observation_count.borrow(), 2);
+
+        type OtherGlobal = f32;
+
+        let observation_count = Rc::new(RefCell::new(0));
+        cx.observe_global::<OtherGlobal, _>({
+            let observation_count = observation_count.clone();
+            move |_| {
+                *observation_count.borrow_mut() += 1;
+            }
+        })
+        .detach();
+
+        assert_eq!(
+            cx.update_default_global::<OtherGlobal, _, _>(|global, _| {
+                assert_eq!(global, &0.0);
+                *global = 2.0;
+                "Default update result"
+            }),
+            "Default update result"
+        );
+        assert_eq!(cx.global::<OtherGlobal>(), &2.0);
+        assert_eq!(*observation_count.borrow(), 1);
+    }
+
     #[crate::test(self)]
     fn test_dropping_subscribers(cx: &mut MutableAppContext) {
         struct View;
@@ -5556,6 +5617,20 @@ mod tests {
         assert_eq!(*observation_count.borrow(), 1);
 
         // Global Observation
+        let observation_count = Rc::new(RefCell::new(0));
+        let subscription = Rc::new(RefCell::new(None));
+        *subscription.borrow_mut() = Some(cx.observe_global::<(), _>({
+            let observation_count = observation_count.clone();
+            let subscription = subscription.clone();
+            move |_| {
+                subscription.borrow_mut().take();
+                *observation_count.borrow_mut() += 1;
+            }
+        }));
+
+        cx.default_global::<()>();
+        cx.set_global(());
+        assert_eq!(*observation_count.borrow(), 1);
     }
 
     #[crate::test(self)]