Add some tests for global events and fix potential bug in subscriptions when subscription is dropped inside of it's own callback

Keith Simmons and Nathan Sobo created

Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

crates/gpui/src/app.rs | 385 ++++++++++++++++++++++++++++++++++++++-----
1 file changed, 341 insertions(+), 44 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -757,9 +757,9 @@ pub struct MutableAppContext {
     next_window_id: usize,
     next_subscription_id: usize,
     frame_count: usize,
-    subscriptions: Arc<Mutex<HashMap<usize, BTreeMap<usize, SubscriptionCallback>>>>,
-    global_subscriptions: Arc<Mutex<HashMap<TypeId, BTreeMap<usize, GlobalSubscriptionCallback>>>>,
-    observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ObservationCallback>>>>,
+    subscriptions: Arc<Mutex<HashMap<usize, BTreeMap<usize, Option<SubscriptionCallback>>>>>,
+    global_subscriptions: Arc<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalSubscriptionCallback>>>>>,
+    observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, Option<ObservationCallback>>>>>,
     release_observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
@@ -1097,10 +1097,10 @@ impl MutableAppContext {
             .or_default()
             .insert(
                 id,
-                Box::new(move |payload, cx| {
+                Some(Box::new(move |payload, cx| {
                     let payload = payload.downcast_ref().expect("downcast is type safe");
                     callback(payload, cx)
-                }),
+                })),
             );
         Subscription::GlobalSubscription {
             id,
@@ -1137,14 +1137,14 @@ impl MutableAppContext {
             .or_default()
             .insert(
                 id,
-                Box::new(move |payload, cx| {
+                Some(Box::new(move |payload, cx| {
                     if let Some(emitter) = H::upgrade_from(&emitter, cx.as_ref()) {
                         let payload = payload.downcast_ref().expect("downcast is type safe");
                         callback(emitter, payload, cx)
                     } else {
                         false
                     }
-                }),
+                })),
             );
         Subscription::Subscription {
             id,
@@ -1168,13 +1168,13 @@ impl MutableAppContext {
             .or_default()
             .insert(
                 id,
-                Box::new(move |cx| {
+                Some(Box::new(move |cx| {
                     if let Some(observed) = H::upgrade_from(&observed, cx) {
                         callback(observed, cx)
                     } else {
                         false
                     }
-                }),
+                })),
             );
         Subscription::Observation {
             id,
@@ -1722,14 +1722,24 @@ impl MutableAppContext {
     fn emit_event(&mut self, entity_id: usize, payload: Box<dyn Any>) {
         let callbacks = self.subscriptions.lock().remove(&entity_id);
         if let Some(callbacks) = callbacks {
-            for (id, mut callback) in callbacks {
-                let alive = callback(payload.as_ref(), self);
-                if alive {
-                    self.subscriptions
-                        .lock()
-                        .entry(entity_id)
-                        .or_default()
-                        .insert(id, callback);
+            for (id, callback) in callbacks {
+                if let Some(mut callback) = callback {
+                    let alive = callback(payload.as_ref(), self);
+                    if alive {
+                        match self.subscriptions
+                            .lock()
+                            .entry(entity_id)
+                            .or_default()
+                            .entry(id)
+                        {
+                            collections::btree_map::Entry::Vacant(entry) => {
+                                entry.insert(Some(callback));
+                            },
+                            collections::btree_map::Entry::Occupied(entry) => {
+                                entry.remove();
+                            },
+                        }
+                    }
                 }
             }
         }
@@ -1739,8 +1749,23 @@ impl MutableAppContext {
         let type_id = (&*payload).type_id();
         let callbacks = self.global_subscriptions.lock().remove(&type_id);
         if let Some(callbacks) = callbacks {
-            for (_, mut callback) in callbacks {
-                callback(payload.as_ref(), self)
+            for (id, callback) in callbacks {
+                if let Some(mut callback) = callback {
+                    callback(payload.as_ref(), self);
+                    match self.global_subscriptions
+                        .lock()
+                        .entry(type_id)
+                        .or_default()
+                        .entry(id) 
+                    {
+                        collections::btree_map::Entry::Vacant(entry) => {
+                            entry.insert(Some(callback));
+                        },
+                        collections::btree_map::Entry::Occupied(entry) => {
+                            entry.remove();
+                        },
+                    }
+                }
             }
         }
     }
@@ -1749,14 +1774,24 @@ impl MutableAppContext {
         let callbacks = self.observations.lock().remove(&observed_id);
         if let Some(callbacks) = callbacks {
             if self.cx.models.contains_key(&observed_id) {
-                for (id, mut callback) in callbacks {
-                    let alive = callback(self);
-                    if alive {
-                        self.observations
-                            .lock()
-                            .entry(observed_id)
-                            .or_default()
-                            .insert(id, callback);
+                for (id, callback) in callbacks {
+                    if let Some(mut callback) = callback {
+                        let alive = callback(self);
+                        if alive {
+                            match self.observations
+                                .lock()
+                                .entry(observed_id)
+                                .or_default()
+                                .entry(id) 
+                            {
+                                collections::btree_map::Entry::Vacant(entry) => {
+                                    entry.insert(Some(callback));
+                                },
+                                collections::btree_map::Entry::Occupied(entry) => {
+                                    entry.remove();
+                                },
+                            }
+                        }
                     }
                 }
             }
@@ -1779,14 +1814,24 @@ impl MutableAppContext {
                 .views
                 .contains_key(&(observed_window_id, observed_view_id))
             {
-                for (id, mut callback) in callbacks {
-                    let alive = callback(self);
-                    if alive {
-                        self.observations
-                            .lock()
-                            .entry(observed_view_id)
-                            .or_default()
-                            .insert(id, callback);
+                for (id, callback) in callbacks {
+                    if let Some(mut callback) = callback {
+                        let alive = callback(self);
+                        if alive {
+                            match self.observations
+                                .lock()
+                                .entry(observed_view_id)
+                                .or_default()
+                                .entry(id) 
+                            {
+                                collections::btree_map::Entry::Vacant(entry) => {
+                                    entry.insert(Some(callback));
+                                },
+                                collections::btree_map::Entry::Occupied(entry) => {
+                                    entry.remove();
+                                },
+                            }
+                        }
                     }
                 }
             }
@@ -3812,18 +3857,18 @@ pub enum Subscription {
     Subscription {
         id: usize,
         entity_id: usize,
-        subscriptions: Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, SubscriptionCallback>>>>>,
+        subscriptions: Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, Option<SubscriptionCallback>>>>>>,
     },
     GlobalSubscription {
         id: usize,
         type_id: TypeId,
         subscriptions:
-            Option<Weak<Mutex<HashMap<TypeId, BTreeMap<usize, GlobalSubscriptionCallback>>>>>,
+            Option<Weak<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalSubscriptionCallback>>>>>>,
     },
     Observation {
         id: usize,
         entity_id: usize,
-        observations: Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, ObservationCallback>>>>>,
+        observations: Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, Option<ObservationCallback>>>>>>,
     },
     ReleaseObservation {
         id: usize,
@@ -3861,8 +3906,18 @@ impl Drop for Subscription {
                 subscriptions,
             } => {
                 if let Some(subscriptions) = subscriptions.as_ref().and_then(Weak::upgrade) {
-                    if let Some(subscriptions) = subscriptions.lock().get_mut(entity_id) {
-                        subscriptions.remove(id);
+                    match subscriptions
+                        .lock()
+                        .entry(*entity_id)
+                        .or_default()
+                        .entry(*id)
+                    {
+                        collections::btree_map::Entry::Vacant(entry) => {
+                            entry.insert(None);
+                        },
+                        collections::btree_map::Entry::Occupied(entry) => {
+                            entry.remove();
+                        },
                     }
                 }
             }
@@ -3872,8 +3927,18 @@ impl Drop for Subscription {
                 subscriptions,
             } => {
                 if let Some(subscriptions) = subscriptions.as_ref().and_then(Weak::upgrade) {
-                    if let Some(subscriptions) = subscriptions.lock().get_mut(type_id) {
-                        subscriptions.remove(id);
+                    match subscriptions
+                        .lock()
+                        .entry(*type_id)
+                        .or_default()
+                        .entry(*id)
+                    {
+                        collections::btree_map::Entry::Vacant(entry) => {
+                            entry.insert(None);
+                        },
+                        collections::btree_map::Entry::Occupied(entry) => {
+                            entry.remove();
+                        },
                     }
                 }
             }
@@ -3883,8 +3948,18 @@ impl Drop for Subscription {
                 observations,
             } => {
                 if let Some(observations) = observations.as_ref().and_then(Weak::upgrade) {
-                    if let Some(observations) = observations.lock().get_mut(entity_id) {
-                        observations.remove(id);
+                    match observations
+                        .lock()
+                        .entry(*entity_id)
+                        .or_default()
+                        .entry(*id)
+                    {
+                        collections::btree_map::Entry::Vacant(entry) => {
+                            entry.insert(None);
+                        },
+                        collections::btree_map::Entry::Occupied(entry) => {
+                            entry.remove();
+                        },
                     }
                 }
             }
@@ -4464,6 +4539,96 @@ mod tests {
         assert_eq!(handle_1.read(cx).events, vec![7, 5, 10, 9]);
     }
 
+    #[crate::test(self)]
+    fn test_global_events(cx: &mut MutableAppContext) {
+        #[derive(Clone, Debug, Eq, PartialEq)]
+        struct GlobalEvent(u64);
+
+        let events = Rc::new(RefCell::new(Vec::new()));
+        let first_subscription;
+        let second_subscription;
+
+        {
+            let events = events.clone();
+            first_subscription = cx.subscribe_global(move |e: &GlobalEvent, _| {
+                events.borrow_mut().push(("First", e.clone()));
+            });
+        }
+
+        {
+            let events = events.clone();
+            second_subscription = cx.subscribe_global(move |e: &GlobalEvent, _| {
+                events.borrow_mut().push(("Second", e.clone()));
+            });
+        }
+
+        cx.update(|cx| {
+            cx.emit_global(GlobalEvent(1));
+            cx.emit_global(GlobalEvent(2));
+        });
+
+        drop(first_subscription);
+
+        cx.update(|cx| {
+            cx.emit_global(GlobalEvent(3));
+        });
+
+        drop(second_subscription);
+
+        cx.update(|cx| {
+            cx.emit_global(GlobalEvent(4));
+        });
+
+        assert_eq!(
+            &*events.borrow(),
+            &[
+                ("First", GlobalEvent(1)),
+                ("Second", GlobalEvent(1)),
+                ("First", GlobalEvent(2)),
+                ("Second", GlobalEvent(2)),
+                ("Second", GlobalEvent(3)),
+            ]
+        );
+    }
+
+    #[crate::test(self)]
+    fn test_global_nested_events(cx: &mut MutableAppContext) {
+        #[derive(Clone, Debug, Eq, PartialEq)]
+        struct GlobalEvent(u64);
+
+        let events = Rc::new(RefCell::new(Vec::new()));
+
+        {
+            let events = events.clone();
+            cx.subscribe_global(move |e: &GlobalEvent, cx| {
+                events.borrow_mut().push(("Outer", e.clone()));
+
+                let events = events.clone();
+                cx.subscribe_global(move |e: &GlobalEvent, _| {
+                    events.borrow_mut().push(("Inner", e.clone()));
+                }).detach();
+            }).detach();
+        }
+
+        cx.update(|cx| {
+            cx.emit_global(GlobalEvent(1));
+            cx.emit_global(GlobalEvent(2));
+            cx.emit_global(GlobalEvent(3));
+        });
+
+        assert_eq!(
+            &*events.borrow(),
+            &[
+                ("Outer", GlobalEvent(1)),
+                ("Outer", GlobalEvent(2)),
+                ("Inner", GlobalEvent(2)),
+                ("Outer", GlobalEvent(3)),
+                ("Inner", GlobalEvent(3)),
+                ("Inner", GlobalEvent(3)),
+            ]
+        );
+    }
+
     #[crate::test(self)]
     fn test_dropping_subscribers(cx: &mut MutableAppContext) {
         struct View;
@@ -4602,6 +4767,138 @@ mod tests {
         observed_model.update(cx, |_, cx| cx.notify());
     }
 
+    #[crate::test(self)]
+    fn test_dropping_subscriptions_during_callback(cx: &mut MutableAppContext) {
+        struct Model;
+
+        impl Entity for Model {
+            type Event = u64;
+        }
+
+        // Events
+        let observing_model = cx.add_model(|_| Model);
+        let observed_model = cx.add_model(|_| Model);
+
+        let events = Rc::new(RefCell::new(Vec::new()));
+
+        observing_model.update(cx, |_, cx| {
+            let events = events.clone();
+            let subscription = Rc::new(RefCell::new(None));
+            *subscription.borrow_mut() = Some(cx.subscribe(&observed_model, {
+                let subscription = subscription.clone();
+                move |_, _, e, _| {
+                    subscription.borrow_mut().take();
+                    events.borrow_mut().push(e.clone());
+                }
+            }));
+        });
+
+        observed_model.update(cx, |_, cx| {
+            cx.emit(1);
+            cx.emit(2);
+        });
+
+        assert_eq!(*events.borrow(), [1]);
+
+
+        // Global Events
+        #[derive(Clone, Debug, Eq, PartialEq)]
+        struct GlobalEvent(u64);
+
+        let events = Rc::new(RefCell::new(Vec::new()));
+
+        {
+            let events = events.clone();
+            let subscription = Rc::new(RefCell::new(None));
+            *subscription.borrow_mut() = Some(cx.subscribe_global({
+                let subscription = subscription.clone();
+                move |e: &GlobalEvent, _| {
+                    subscription.borrow_mut().take();
+                    events.borrow_mut().push(e.clone());
+                }
+            }));
+        }
+
+        cx.update(|cx| {
+            cx.emit_global(GlobalEvent(1));
+            cx.emit_global(GlobalEvent(2));
+        });
+
+        assert_eq!(*events.borrow(), [GlobalEvent(1)]);
+
+        // Model Observation
+        let observing_model = cx.add_model(|_| Model);
+        let observed_model = cx.add_model(|_| Model);
+
+        let observation_count = Rc::new(RefCell::new(0));
+
+        observing_model.update(cx, |_, cx| {
+            let observation_count = observation_count.clone();
+            let subscription = Rc::new(RefCell::new(None));
+            *subscription.borrow_mut() = Some(cx.observe(&observed_model, {
+                let subscription = subscription.clone();
+                move |_, _, _| {
+                    subscription.borrow_mut().take();
+                    *observation_count.borrow_mut() += 1;
+                }
+            }));
+        });
+
+        observed_model.update(cx, |_, cx| {
+            cx.notify();
+        });
+
+        observed_model.update(cx, |_, cx| {
+            cx.notify();
+        });
+
+        assert_eq!(*observation_count.borrow(), 1);
+
+        // View Observation
+        struct View;
+
+        impl Entity for View {
+            type Event = ();
+        }
+
+        impl super::View for View {
+            fn render(&mut self, _: &mut RenderContext<Self>) -> ElementBox {
+                Empty::new().boxed()
+            }
+
+            fn ui_name() -> &'static str {
+                "View"
+            }
+        }
+
+        let (window_id, _) = cx.add_window(Default::default(), |_| View);
+        let observing_view = cx.add_view(window_id, |_| View);
+        let observed_view = cx.add_view(window_id, |_| View);
+
+        let observation_count = Rc::new(RefCell::new(0));
+        observing_view.update(cx, |_, cx| {
+            let observation_count = observation_count.clone();
+            let subscription = Rc::new(RefCell::new(None));
+            *subscription.borrow_mut() = Some(cx.observe(&observed_view, {
+                let subscription = subscription.clone();
+                move |_, _, _| {
+                    subscription.borrow_mut().take();
+                    *observation_count.borrow_mut() += 1;
+                }
+            }));
+        });
+
+        observed_view.update(cx, |_, cx| {
+            cx.notify();
+        });
+
+        observed_view.update(cx, |_, cx| {
+            cx.notify();
+        });
+
+        assert_eq!(*observation_count.borrow(), 1);
+    }
+
     #[crate::test(self)]
     fn test_focus(cx: &mut MutableAppContext) {
         struct View {