Merge pull request #657 from zed-industries/global-observations

Max Brunsfeld created

Add global change observations

Change summary

crates/gpui/src/app.rs | 244 +++++++++++++++++++++++++++++++++++++++----
1 file changed, 219 insertions(+), 25 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -763,6 +763,7 @@ type GlobalActionCallback = dyn FnMut(&dyn AnyAction, &mut MutableAppContext);
 type SubscriptionCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext) -> bool>;
 type GlobalSubscriptionCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>;
 type ObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
+type GlobalObservationCallback = Box<dyn FnMut(&mut MutableAppContext)>;
 type ReleaseObservationCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>;
 
 pub struct MutableAppContext {
@@ -782,12 +783,15 @@ pub struct MutableAppContext {
     global_subscriptions:
         Arc<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalSubscriptionCallback>>>>>,
     observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, Option<ObservationCallback>>>>>,
+    global_observations:
+        Arc<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalObservationCallback>>>>>,
     release_observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
     foreground: Rc<executor::Foreground>,
     pending_effects: VecDeque<Effect>,
     pending_notifications: HashSet<usize>,
+    pending_global_notifications: HashSet<TypeId>,
     pending_flushes: usize,
     flushing_effects: bool,
     next_cursor_style_handle_id: Arc<AtomicUsize>,
@@ -831,10 +835,12 @@ impl MutableAppContext {
             global_subscriptions: Default::default(),
             observations: Default::default(),
             release_observations: Default::default(),
+            global_observations: Default::default(),
             presenters_and_platform_windows: HashMap::new(),
             foreground,
             pending_effects: VecDeque::new(),
             pending_notifications: HashSet::new(),
+            pending_global_notifications: HashSet::new(),
             pending_flushes: 0,
             flushing_effects: false,
             next_cursor_style_handle_id: Default::default(),
@@ -1197,6 +1203,27 @@ impl MutableAppContext {
         }
     }
 
+    pub fn observe_global<G, F>(&mut self, observe: F) -> Subscription
+    where
+        G: Any,
+        F: 'static + FnMut(&mut MutableAppContext),
+    {
+        let type_id = TypeId::of::<G>();
+        let id = post_inc(&mut self.next_subscription_id);
+
+        self.global_observations
+            .lock()
+            .entry(type_id)
+            .or_default()
+            .insert(id, Some(Box::new(observe)));
+
+        Subscription::GlobalObservation {
+            id,
+            type_id,
+            observations: Some(Arc::downgrade(&self.global_observations)),
+        }
+    }
+
     pub fn observe_release<E, H, F>(&mut self, handle: &H, mut callback: F) -> Subscription
     where
         E: Entity,
@@ -1251,6 +1278,13 @@ impl MutableAppContext {
         }
     }
 
+    pub(crate) fn notify_global(&mut self, type_id: TypeId) {
+        if self.pending_global_notifications.insert(type_id) {
+            self.pending_effects
+                .push_back(Effect::GlobalNotification { type_id });
+        }
+    }
+
     pub fn dispatch_action<A: Action>(
         &mut self,
         window_id: usize,
@@ -1380,16 +1414,26 @@ impl MutableAppContext {
     }
 
     pub fn default_global<T: 'static + Default>(&mut self) -> &T {
-        self.cx
-            .globals
-            .entry(TypeId::of::<T>())
-            .or_insert_with(|| Box::new(T::default()))
-            .downcast_ref()
-            .unwrap()
+        let type_id = TypeId::of::<T>();
+        self.update(|this| {
+            if !this.globals.contains_key(&type_id) {
+                this.notify_global(type_id);
+            }
+
+            this.cx
+                .globals
+                .entry(type_id)
+                .or_insert_with(|| Box::new(T::default()));
+        });
+        self.globals.get(&type_id).unwrap().downcast_ref().unwrap()
     }
 
     pub fn set_global<T: 'static>(&mut self, state: T) {
-        self.cx.globals.insert(TypeId::of::<T>(), Box::new(state));
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            this.cx.globals.insert(type_id, Box::new(state));
+            this.notify_global(type_id);
+        });
     }
 
     pub fn update_default_global<T, F, U>(&mut self, update: F) -> U
@@ -1397,15 +1441,18 @@ impl MutableAppContext {
         T: 'static + Default,
         F: FnOnce(&mut T, &mut MutableAppContext) -> U,
     {
-        let type_id = TypeId::of::<T>();
-        let mut state = self
-            .cx
-            .globals
-            .remove(&type_id)
-            .unwrap_or_else(|| Box::new(T::default()));
-        let result = update(state.downcast_mut().unwrap(), self);
-        self.cx.globals.insert(type_id, state);
-        result
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            let mut state = this
+                .cx
+                .globals
+                .remove(&type_id)
+                .unwrap_or_else(|| Box::new(T::default()));
+            let result = update(state.downcast_mut().unwrap(), this);
+            this.cx.globals.insert(type_id, state);
+            this.notify_global(type_id);
+            result
+        })
     }
 
     pub fn update_global<T, F, U>(&mut self, update: F) -> U
@@ -1413,15 +1460,18 @@ impl MutableAppContext {
         T: 'static,
         F: FnOnce(&mut T, &mut MutableAppContext) -> U,
     {
-        let type_id = TypeId::of::<T>();
-        let mut state = self
-            .cx
-            .globals
-            .remove(&type_id)
-            .expect("no global has been added for this type");
-        let result = update(state.downcast_mut().unwrap(), self);
-        self.cx.globals.insert(type_id, state);
-        result
+        self.update(|this| {
+            let type_id = TypeId::of::<T>();
+            let mut state = this
+                .cx
+                .globals
+                .remove(&type_id)
+                .expect("no global has been added for this type");
+            let result = update(state.downcast_mut().unwrap(), this);
+            this.cx.globals.insert(type_id, state);
+            this.notify_global(type_id);
+            result
+        })
     }
 
     pub fn add_model<T, F>(&mut self, build_model: F) -> ModelHandle<T>
@@ -1686,6 +1736,9 @@ impl MutableAppContext {
                         Effect::ViewNotification { window_id, view_id } => {
                             self.notify_view_observers(window_id, view_id)
                         }
+                        Effect::GlobalNotification { type_id } => {
+                            self.notify_global_observers(type_id)
+                        }
                         Effect::Deferred {
                             callback,
                             after_window_update,
@@ -1734,6 +1787,7 @@ impl MutableAppContext {
                         if self.pending_effects.is_empty() {
                             self.flushing_effects = false;
                             self.pending_notifications.clear();
+                            self.pending_global_notifications.clear();
                             break;
                         }
                     }
@@ -2004,6 +2058,33 @@ impl MutableAppContext {
         }
     }
 
+    fn notify_global_observers(&mut self, observed_type_id: TypeId) {
+        let callbacks = self.global_observations.lock().remove(&observed_type_id);
+        if let Some(callbacks) = callbacks {
+            if self.cx.globals.contains_key(&observed_type_id) {
+                for (id, callback) in callbacks {
+                    if let Some(mut callback) = callback {
+                        callback(self);
+                        match self
+                            .global_observations
+                            .lock()
+                            .entry(observed_type_id)
+                            .or_default()
+                            .entry(id)
+                        {
+                            collections::btree_map::Entry::Vacant(entry) => {
+                                entry.insert(Some(callback));
+                            }
+                            collections::btree_map::Entry::Occupied(entry) => {
+                                entry.remove();
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     fn notify_release_observers(&mut self, entity_id: usize, entity: &dyn Any) {
         let callbacks = self.release_observations.lock().remove(&entity_id);
         if let Some(callbacks) = callbacks {
@@ -2377,6 +2458,9 @@ pub enum Effect {
         callback: Box<dyn FnOnce(&mut MutableAppContext)>,
         after_window_update: bool,
     },
+    GlobalNotification {
+        type_id: TypeId,
+    },
     ModelRelease {
         model_id: usize,
         model: Box<dyn AnyModel>,
@@ -2442,6 +2526,10 @@ impl Debug for Effect {
                 .field("window_id", window_id)
                 .field("view_id", view_id)
                 .finish(),
+            Effect::GlobalNotification { type_id } => f
+                .debug_struct("Effect::GlobalNotification")
+                .field("type_id", type_id)
+                .finish(),
             Effect::Deferred { .. } => f.debug_struct("Effect::Deferred").finish(),
             Effect::ModelRelease { model_id, .. } => f
                 .debug_struct("Effect::ModelRelease")
@@ -2621,6 +2709,15 @@ impl<'a, T: Entity> ModelContext<'a, T> {
         self.app.add_model(build_model)
     }
 
+    pub fn defer(&mut self, callback: impl 'static + FnOnce(&mut T, &mut ModelContext<T>)) {
+        let handle = self.handle();
+        self.app.defer(Box::new(move |cx| {
+            handle.update(cx, |model, cx| {
+                callback(model, cx);
+            })
+        }))
+    }
+
     pub fn emit(&mut self, payload: T::Event) {
         self.app.pending_effects.push_back(Effect::Event {
             entity_id: self.model_id,
@@ -4178,6 +4275,13 @@ pub enum Subscription {
         observations:
             Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, Option<ObservationCallback>>>>>>,
     },
+    GlobalObservation {
+        id: usize,
+        type_id: TypeId,
+        observations: Option<
+            Weak<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalObservationCallback>>>>>,
+        >,
+    },
     ReleaseObservation {
         id: usize,
         entity_id: usize,
@@ -4198,6 +4302,9 @@ impl Subscription {
             Subscription::Observation { observations, .. } => {
                 observations.take();
             }
+            Subscription::GlobalObservation { observations, .. } => {
+                observations.take();
+            }
             Subscription::ReleaseObservation { observations, .. } => {
                 observations.take();
             }
@@ -4266,6 +4373,22 @@ impl Drop for Subscription {
                     }
                 }
             }
+            Subscription::GlobalObservation {
+                id,
+                type_id,
+                observations,
+            } => {
+                if let Some(observations) = observations.as_ref().and_then(Weak::upgrade) {
+                    match observations.lock().entry(*type_id).or_default().entry(*id) {
+                        collections::btree_map::Entry::Vacant(entry) => {
+                            entry.insert(None);
+                        }
+                        collections::btree_map::Entry::Occupied(entry) => {
+                            entry.remove();
+                        }
+                    }
+                }
+            }
             Subscription::ReleaseObservation {
                 id,
                 entity_id,
@@ -5088,6 +5211,61 @@ mod tests {
         );
     }
 
+    #[crate::test(self)]
+    fn test_global(cx: &mut MutableAppContext) {
+        type Global = usize;
+
+        let observation_count = Rc::new(RefCell::new(0));
+        let subscription = cx.observe_global::<Global, _>({
+            let observation_count = observation_count.clone();
+            move |_| {
+                *observation_count.borrow_mut() += 1;
+            }
+        });
+
+        assert!(!cx.has_global::<Global>());
+        assert_eq!(cx.default_global::<Global>(), &0);
+        assert_eq!(*observation_count.borrow(), 1);
+        assert!(cx.has_global::<Global>());
+        assert_eq!(
+            cx.update_global::<Global, _, _>(|global, _| {
+                *global = 1;
+                "Update Result"
+            }),
+            "Update Result"
+        );
+        assert_eq!(*observation_count.borrow(), 2);
+        assert_eq!(cx.global::<Global>(), &1);
+
+        drop(subscription);
+        cx.update_global::<Global, _, _>(|global, _| {
+            *global = 2;
+        });
+        assert_eq!(*observation_count.borrow(), 2);
+
+        type OtherGlobal = f32;
+
+        let observation_count = Rc::new(RefCell::new(0));
+        cx.observe_global::<OtherGlobal, _>({
+            let observation_count = observation_count.clone();
+            move |_| {
+                *observation_count.borrow_mut() += 1;
+            }
+        })
+        .detach();
+
+        assert_eq!(
+            cx.update_default_global::<OtherGlobal, _, _>(|global, _| {
+                assert_eq!(global, &0.0);
+                *global = 2.0;
+                "Default update result"
+            }),
+            "Default update result"
+        );
+        assert_eq!(cx.global::<OtherGlobal>(), &2.0);
+        assert_eq!(*observation_count.borrow(), 1);
+    }
+
     #[crate::test(self)]
     fn test_dropping_subscribers(cx: &mut MutableAppContext) {
         struct View;
@@ -5437,6 +5615,22 @@ mod tests {
         });
 
         assert_eq!(*observation_count.borrow(), 1);
+
+        // Global Observation
+        let observation_count = Rc::new(RefCell::new(0));
+        let subscription = Rc::new(RefCell::new(None));
+        *subscription.borrow_mut() = Some(cx.observe_global::<(), _>({
+            let observation_count = observation_count.clone();
+            let subscription = subscription.clone();
+            move |_| {
+                subscription.borrow_mut().take();
+                *observation_count.borrow_mut() += 1;
+            }
+        }));
+
+        cx.default_global::<()>();
+        cx.set_global(());
+        assert_eq!(*observation_count.borrow(), 1);
     }
 
     #[crate::test(self)]