Simplify state associated with observations

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

gpui/src/app.rs | 293 +++++++++++++++++++++++++++++++++-----------------
1 file changed, 192 insertions(+), 101 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -650,8 +650,7 @@ pub struct MutableAppContext {
     next_entity_id: usize,
     next_window_id: usize,
     subscriptions: HashMap<usize, Vec<Subscription>>,
-    model_observations: HashMap<usize, Vec<ModelObservation>>,
-    view_observations: HashMap<usize, Vec<ViewObservation>>,
+    observations: HashMap<usize, Vec<Observation>>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
     debug_elements_callbacks: HashMap<usize, Box<dyn Fn(&AppContext) -> crate::json::Value>>,
@@ -690,8 +689,7 @@ impl MutableAppContext {
             next_entity_id: 0,
             next_window_id: 0,
             subscriptions: HashMap::new(),
-            model_observations: HashMap::new(),
-            view_observations: HashMap::new(),
+            observations: HashMap::new(),
             presenters_and_platform_windows: HashMap::new(),
             debug_elements_callbacks: HashMap::new(),
             foreground,
@@ -879,6 +877,93 @@ impl MutableAppContext {
         );
     }
 
+    pub fn subscribe_to_model<E, F>(&mut self, handle: &ModelHandle<E>, mut callback: F)
+    where
+        E: Entity,
+        E::Event: 'static,
+        F: 'static + FnMut(ModelHandle<E>, &E::Event, &mut Self),
+    {
+        let emitter_handle = handle.downgrade();
+        self.subscribe(handle, move |payload, cx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(cx.as_ref()) {
+                callback(emitter_handle, payload, cx);
+            }
+        });
+    }
+
+    pub fn subscribe_to_view<V, F>(&mut self, handle: &ViewHandle<V>, mut callback: F)
+    where
+        V: View,
+        V::Event: 'static,
+        F: 'static + FnMut(ViewHandle<V>, &V::Event, &mut Self),
+    {
+        let emitter_handle = handle.downgrade();
+        self.subscribe(handle, move |payload, cx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(cx.as_ref()) {
+                callback(emitter_handle, payload, cx);
+            }
+        });
+    }
+
+    pub fn observe_model<E, F>(&mut self, handle: &ModelHandle<E>, mut callback: F)
+    where
+        E: Entity,
+        E::Event: 'static,
+        F: 'static + FnMut(ModelHandle<E>, &mut Self),
+    {
+        let emitter_handle = handle.downgrade();
+        self.observe(handle, move |cx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(cx.as_ref()) {
+                callback(emitter_handle, cx);
+            }
+        });
+    }
+
+    pub fn observe_view<V, F>(&mut self, handle: &ViewHandle<V>, mut callback: F)
+    where
+        V: View,
+        V::Event: 'static,
+        F: 'static + FnMut(ViewHandle<V>, &mut Self),
+    {
+        let emitter_handle = handle.downgrade();
+        self.observe(handle, move |cx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(cx.as_ref()) {
+                callback(emitter_handle, cx);
+            }
+        });
+    }
+
+    pub fn subscribe<E, F>(&mut self, handle: &impl Handle<E>, mut callback: F)
+    where
+        E: Entity,
+        E::Event: 'static,
+        F: 'static + FnMut(&E::Event, &mut Self),
+    {
+        self.subscriptions
+            .entry(handle.id())
+            .or_default()
+            .push(Subscription::Global {
+                callback: Box::new(move |payload, cx| {
+                    let payload = payload.downcast_ref().expect("downcast is type safe");
+                    callback(payload, cx);
+                }),
+            });
+    }
+
+    pub fn observe<E, F>(&mut self, handle: &impl Handle<E>, callback: F)
+    where
+        E: Entity,
+        E::Event: 'static,
+        F: 'static + FnMut(&mut Self),
+    {
+        self.observations
+            .entry(handle.id())
+            .or_default()
+            .push(Observation::Global {
+                callback: Box::new(callback),
+            });
+    }
+
     pub(crate) fn notify_view(&mut self, window_id: usize, view_id: usize) {
         self.pending_effects
             .push_back(Effect::ViewNotification { window_id, view_id });
@@ -1184,14 +1269,14 @@ impl MutableAppContext {
 
             for model_id in dropped_models {
                 self.subscriptions.remove(&model_id);
-                self.model_observations.remove(&model_id);
+                self.observations.remove(&model_id);
                 let mut model = self.cx.models.remove(&model_id).unwrap();
                 model.release(self);
             }
 
             for (window_id, view_id) in dropped_views {
                 self.subscriptions.remove(&view_id);
-                self.model_observations.remove(&view_id);
+                self.observations.remove(&view_id);
                 let mut view = self.cx.views.remove(&(window_id, view_id)).unwrap();
                 view.release(self);
                 let change_focus_to = self.cx.windows.get_mut(&window_id).and_then(|window| {
@@ -1281,6 +1366,10 @@ impl MutableAppContext {
         if let Some(subscriptions) = self.subscriptions.remove(&entity_id) {
             for mut subscription in subscriptions {
                 let alive = match &mut subscription {
+                    Subscription::Global { callback } => {
+                        callback(payload.as_ref(), self);
+                        true
+                    }
                     Subscription::FromModel { model_id, callback } => {
                         if let Some(mut model) = self.cx.models.remove(model_id) {
                             callback(model.as_any_mut(), payload.as_ref(), self, *model_id);
@@ -1322,32 +1411,30 @@ impl MutableAppContext {
     }
 
     fn notify_model_observers(&mut self, observed_id: usize) {
-        if let Some(observations) = self.model_observations.remove(&observed_id) {
+        if let Some(observations) = self.observations.remove(&observed_id) {
             if self.cx.models.contains_key(&observed_id) {
                 for mut observation in observations {
                     let alive = match &mut observation {
-                        ModelObservation::FromModel { model_id, callback } => {
+                        Observation::Global { callback } => {
+                            callback(self);
+                            true
+                        }
+                        Observation::FromModel { model_id, callback } => {
                             if let Some(mut model) = self.cx.models.remove(model_id) {
-                                callback(model.as_any_mut(), observed_id, self, *model_id);
+                                callback(model.as_any_mut(), self, *model_id);
                                 self.cx.models.insert(*model_id, model);
                                 true
                             } else {
                                 false
                             }
                         }
-                        ModelObservation::FromView {
+                        Observation::FromView {
                             window_id,
                             view_id,
                             callback,
                         } => {
                             if let Some(mut view) = self.cx.views.remove(&(*window_id, *view_id)) {
-                                callback(
-                                    view.as_any_mut(),
-                                    observed_id,
-                                    self,
-                                    *window_id,
-                                    *view_id,
-                                );
+                                callback(view.as_any_mut(), self, *window_id, *view_id);
                                 self.cx.views.insert((*window_id, *view_id), view);
                                 true
                             } else {
@@ -1357,7 +1444,7 @@ impl MutableAppContext {
                     };
 
                     if alive {
-                        self.model_observations
+                        self.observations
                             .entry(observed_id)
                             .or_default()
                             .push(observation);
@@ -1367,44 +1454,55 @@ impl MutableAppContext {
         }
     }
 
-    fn notify_view_observers(&mut self, window_id: usize, view_id: usize) {
-        if let Some(window) = self.cx.windows.get_mut(&window_id) {
+    fn notify_view_observers(&mut self, observed_window_id: usize, observed_view_id: usize) {
+        if let Some(window) = self.cx.windows.get_mut(&observed_window_id) {
             window
                 .invalidation
                 .get_or_insert_with(Default::default)
                 .updated
-                .insert(view_id);
+                .insert(observed_view_id);
         }
 
-        if let Some(observations) = self.view_observations.remove(&view_id) {
-            if self.cx.views.contains_key(&(window_id, view_id)) {
+        if let Some(observations) = self.observations.remove(&observed_view_id) {
+            if self
+                .cx
+                .views
+                .contains_key(&(observed_window_id, observed_view_id))
+            {
                 for mut observation in observations {
-                    let alive = if let Some(mut view) = self
-                        .cx
-                        .views
-                        .remove(&(observation.window_id, observation.view_id))
+                    if let Observation::FromView {
+                        window_id: observing_window_id,
+                        view_id: observing_view_id,
+                        callback,
+                    } = &mut observation
                     {
-                        (observation.callback)(
-                            view.as_any_mut(),
-                            view_id,
-                            window_id,
-                            self,
-                            observation.window_id,
-                            observation.view_id,
-                        );
-                        self.cx
+                        let alive = if let Some(mut view) = self
+                            .cx
                             .views
-                            .insert((observation.window_id, observation.view_id), view);
-                        true
-                    } else {
-                        false
-                    };
+                            .remove(&(*observing_window_id, *observing_view_id))
+                        {
+                            (callback)(
+                                view.as_any_mut(),
+                                self,
+                                *observing_window_id,
+                                *observing_view_id,
+                            );
+                            self.cx
+                                .views
+                                .insert((*observing_window_id, *observing_view_id), view);
+                            true
+                        } else {
+                            false
+                        };
 
-                    if alive {
-                        self.view_observations
-                            .entry(view_id)
-                            .or_default()
-                            .push(observation);
+                        if alive {
+                            self.observations
+                                .entry(observed_view_id)
+                                .or_default()
+                                .push(observation);
+                        }
+                    } else {
+                        unreachable!()
                     }
                 }
             }
@@ -1901,17 +1999,19 @@ impl<'a, T: Entity> ModelContext<'a, T> {
         S: Entity,
         F: 'static + FnMut(&mut T, ModelHandle<S>, &mut ModelContext<T>),
     {
+        let observed_handle = handle.downgrade();
         self.app
-            .model_observations
+            .observations
             .entry(handle.model_id)
             .or_default()
-            .push(ModelObservation::FromModel {
+            .push(Observation::FromModel {
                 model_id: self.model_id,
-                callback: Box::new(move |model, observed_id, app, model_id| {
-                    let model = model.downcast_mut().expect("downcast is type safe");
-                    let observed = ModelHandle::new(observed_id, &app.cx.ref_counts);
-                    let mut cx = ModelContext::new(app, model_id);
-                    callback(model, observed, &mut cx);
+                callback: Box::new(move |model, app, model_id| {
+                    if let Some(observed) = observed_handle.upgrade(app) {
+                        let model = model.downcast_mut().expect("downcast is type safe");
+                        let mut cx = ModelContext::new(app, model_id);
+                        callback(model, observed, &mut cx);
+                    }
                 }),
             });
     }
@@ -2173,18 +2273,20 @@ impl<'a, T: View> ViewContext<'a, T> {
         S: Entity,
         F: 'static + FnMut(&mut T, ModelHandle<S>, &mut ViewContext<T>),
     {
+        let observed_handle = handle.downgrade();
         self.app
-            .model_observations
+            .observations
             .entry(handle.id())
             .or_default()
-            .push(ModelObservation::FromView {
+            .push(Observation::FromView {
                 window_id: self.window_id,
                 view_id: self.view_id,
-                callback: Box::new(move |view, observed_id, app, window_id, view_id| {
-                    let view = view.downcast_mut().expect("downcast is type safe");
-                    let observed = ModelHandle::new(observed_id, &app.cx.ref_counts);
-                    let mut cx = ViewContext::new(app, window_id, view_id);
-                    callback(view, observed, &mut cx);
+                callback: Box::new(move |view, app, window_id, view_id| {
+                    if let Some(observed) = observed_handle.upgrade(app) {
+                        let view = view.downcast_mut().expect("downcast is type safe");
+                        let mut cx = ViewContext::new(app, window_id, view_id);
+                        callback(view, observed, &mut cx);
+                    }
                 }),
             });
     }
@@ -2194,30 +2296,21 @@ impl<'a, T: View> ViewContext<'a, T> {
         S: View,
         F: 'static + FnMut(&mut T, ViewHandle<S>, &mut ViewContext<T>),
     {
+        let observed_handle = handle.downgrade();
         self.app
-            .view_observations
+            .observations
             .entry(handle.id())
             .or_default()
-            .push(ViewObservation {
+            .push(Observation::FromView {
                 window_id: self.window_id,
                 view_id: self.view_id,
-                callback: Box::new(
-                    move |view,
-                          observed_view_id,
-                          observed_window_id,
-                          app,
-                          observing_window_id,
-                          observing_view_id| {
+                callback: Box::new(move |view, app, observing_window_id, observing_view_id| {
+                    if let Some(observed) = observed_handle.upgrade(app) {
                         let view = view.downcast_mut().expect("downcast is type safe");
-                        let observed_handle = ViewHandle::new(
-                            observed_view_id,
-                            observed_window_id,
-                            &app.cx.ref_counts,
-                        );
                         let mut cx = ViewContext::new(app, observing_window_id, observing_view_id);
-                        callback(view, observed_handle, &mut cx);
-                    },
-                ),
+                        callback(view, observed, &mut cx);
+                    }
+                }),
             });
     }
 
@@ -2402,19 +2495,17 @@ impl<T: Entity> ModelHandle<T> {
         let (tx, mut rx) = mpsc::channel(1024);
 
         let mut cx = cx.cx.borrow_mut();
-        self.update(&mut *cx, |_, cx| {
-            cx.observe(self, {
-                let mut tx = tx.clone();
-                move |_, _, _| {
-                    tx.blocking_send(()).ok();
-                }
-            });
-            cx.subscribe(self, {
-                let mut tx = tx.clone();
-                move |_, _, _| {
-                    tx.blocking_send(()).ok();
-                }
-            })
+        cx.observe_model(self, {
+            let mut tx = tx.clone();
+            move |_, _| {
+                tx.blocking_send(()).ok();
+            }
+        });
+        cx.subscribe_to_model(self, {
+            let mut tx = tx.clone();
+            move |_, _, _| {
+                tx.blocking_send(()).ok();
+            }
         });
 
         let cx = cx.weak_self.as_ref().unwrap().upgrade().unwrap();
@@ -3007,6 +3098,9 @@ impl RefCounts {
 }
 
 enum Subscription {
+    Global {
+        callback: Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>,
+    },
     FromModel {
         model_id: usize,
         callback: Box<dyn FnMut(&mut dyn Any, &dyn Any, &mut MutableAppContext, usize)>,
@@ -3018,24 +3112,21 @@ enum Subscription {
     },
 }
 
-enum ModelObservation {
+enum Observation {
+    Global {
+        callback: Box<dyn FnMut(&mut MutableAppContext)>,
+    },
     FromModel {
         model_id: usize,
-        callback: Box<dyn FnMut(&mut dyn Any, usize, &mut MutableAppContext, usize)>,
+        callback: Box<dyn FnMut(&mut dyn Any, &mut MutableAppContext, usize)>,
     },
     FromView {
         window_id: usize,
         view_id: usize,
-        callback: Box<dyn FnMut(&mut dyn Any, usize, &mut MutableAppContext, usize, usize)>,
+        callback: Box<dyn FnMut(&mut dyn Any, &mut MutableAppContext, usize, usize)>,
     },
 }
 
-struct ViewObservation {
-    window_id: usize,
-    view_id: usize,
-    callback: Box<dyn FnMut(&mut dyn Any, usize, usize, &mut MutableAppContext, usize, usize)>,
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -3099,7 +3190,7 @@ mod tests {
 
         assert_eq!(cx.cx.models.len(), 1);
         assert!(cx.subscriptions.is_empty());
-        assert!(cx.model_observations.is_empty());
+        assert!(cx.observations.is_empty());
     }
 
     #[crate::test(self)]
@@ -3233,7 +3324,7 @@ mod tests {
 
         assert_eq!(cx.cx.views.len(), 2);
         assert!(cx.subscriptions.is_empty());
-        assert!(cx.model_observations.is_empty());
+        assert!(cx.observations.is_empty());
     }
 
     #[crate::test(self)]