Introduce `{MutableAppContext,ViewContext}::observe_actions`

Antonio Scandurra created

Change summary

crates/gpui/src/app.rs | 80 +++++++++++++++++++++++++++++++++++++++++++
1 file changed, 79 insertions(+), 1 deletion(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -759,6 +759,7 @@ type ObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
 type FocusObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
 type GlobalObservationCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext)>;
 type ReleaseObservationCallback = Box<dyn FnOnce(&dyn Any, &mut MutableAppContext)>;
+type ActionObservationCallback = Box<dyn FnMut(TypeId, &mut MutableAppContext)>;
 type DeserializeActionCallback = fn(json: &str) -> anyhow::Result<Box<dyn Action>>;
 
 pub struct MutableAppContext {
@@ -784,6 +785,7 @@ pub struct MutableAppContext {
     global_observations:
         Arc<Mutex<HashMap<TypeId, BTreeMap<usize, Option<GlobalObservationCallback>>>>>,
     release_observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>,
+    action_dispatch_observations: Arc<Mutex<BTreeMap<usize, ActionObservationCallback>>>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
     foreground: Rc<executor::Foreground>,
@@ -836,6 +838,7 @@ impl MutableAppContext {
             focus_observations: Default::default(),
             release_observations: Default::default(),
             global_observations: Default::default(),
+            action_dispatch_observations: Default::default(),
             presenters_and_platform_windows: HashMap::new(),
             foreground,
             pending_effects: VecDeque::new(),
@@ -1320,6 +1323,20 @@ impl MutableAppContext {
         }
     }
 
+    pub fn observe_actions<F>(&mut self, callback: F) -> Subscription
+    where
+        F: 'static + FnMut(TypeId, &mut MutableAppContext),
+    {
+        let id = post_inc(&mut self.next_subscription_id);
+        self.action_dispatch_observations
+            .lock()
+            .insert(id, Box::new(callback));
+        Subscription::ActionObservation {
+            id,
+            observations: Some(Arc::downgrade(&self.action_dispatch_observations)),
+        }
+    }
+
     pub fn defer(&mut self, callback: impl 'static + FnOnce(&mut MutableAppContext)) {
         self.pending_effects.push_back(Effect::Deferred {
             callback: Box::new(callback),
@@ -1513,6 +1530,11 @@ impl MutableAppContext {
             if !this.halt_action_dispatch {
                 this.halt_action_dispatch = this.dispatch_global_action_any(action);
             }
+
+            this.pending_effects
+                .push_back(Effect::ActionDispatchNotification {
+                    action_id: action.id(),
+                });
             this.halt_action_dispatch
         })
     }
@@ -1961,6 +1983,9 @@ impl MutableAppContext {
                         Effect::RefreshWindows => {
                             refreshing = true;
                         }
+                        Effect::ActionDispatchNotification { action_id } => {
+                            self.handle_action_dispatch_notification_effect(action_id)
+                        }
                     }
                     self.pending_notifications.clear();
                     self.remove_dropped_entities();
@@ -2402,6 +2427,14 @@ impl MutableAppContext {
         })
     }
 
+    fn handle_action_dispatch_notification_effect(&mut self, action_id: TypeId) {
+        let mut callbacks = mem::take(&mut *self.action_dispatch_observations.lock());
+        for (_, callback) in &mut callbacks {
+            callback(action_id, self);
+        }
+        self.action_dispatch_observations.lock().extend(callbacks);
+    }
+
     pub fn focus(&mut self, window_id: usize, view_id: Option<usize>) {
         if let Some(pending_focus_index) = self.pending_focus_index {
             self.pending_effects.remove(pending_focus_index);
@@ -2776,6 +2809,9 @@ pub enum Effect {
         is_active: bool,
     },
     RefreshWindows,
+    ActionDispatchNotification {
+        action_id: TypeId,
+    },
 }
 
 impl Debug for Effect {
@@ -2852,6 +2888,10 @@ impl Debug for Effect {
                 .field("view_id", view_id)
                 .field("subscription_id", subscription_id)
                 .finish(),
+            Effect::ActionDispatchNotification { action_id, .. } => f
+                .debug_struct("Effect::ActionDispatchNotification")
+                .field("action_id", action_id)
+                .finish(),
             Effect::ResizeWindow { window_id } => f
                 .debug_struct("Effect::RefreshWindow")
                 .field("window_id", window_id)
@@ -3376,6 +3416,20 @@ impl<'a, T: View> ViewContext<'a, T> {
         })
     }
 
+    pub fn observe_actions<F>(&mut self, mut callback: F) -> Subscription
+    where
+        F: 'static + FnMut(&mut T, TypeId, &mut ViewContext<T>),
+    {
+        let observer = self.weak_handle();
+        self.app.observe_actions(move |action_id, cx| {
+            if let Some(observer) = observer.upgrade(cx) {
+                observer.update(cx, |observer, cx| {
+                    callback(observer, action_id, cx);
+                });
+            }
+        })
+    }
+
     pub fn emit(&mut self, payload: T::Event) {
         self.app.pending_effects.push_back(Effect::Event {
             entity_id: self.view_id,
@@ -4682,6 +4736,10 @@ pub enum Subscription {
         observations:
             Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>>,
     },
+    ActionObservation {
+        id: usize,
+        observations: Option<Weak<Mutex<BTreeMap<usize, ActionObservationCallback>>>>,
+    },
 }
 
 impl Subscription {
@@ -4705,6 +4763,9 @@ impl Subscription {
             Subscription::FocusObservation { observations, .. } => {
                 observations.take();
             }
+            Subscription::ActionObservation { observations, .. } => {
+                observations.take();
+            }
         }
     }
 }
@@ -4813,6 +4874,11 @@ impl Drop for Subscription {
                     }
                 }
             }
+            Subscription::ActionObservation { id, observations } => {
+                if let Some(observations) = observations.as_ref().and_then(Weak::upgrade) {
+                    observations.lock().remove(&id);
+                }
+            }
         }
     }
 }
@@ -6246,7 +6312,7 @@ mod tests {
             }
         }
 
-        #[derive(Clone, Deserialize)]
+        #[derive(Clone, Default, Deserialize)]
         pub struct Action(pub String);
 
         impl_actions!(test, [Action]);
@@ -6311,6 +6377,13 @@ mod tests {
         let view_3 = cx.add_view(window_id, |_| ViewA { id: 3 });
         let view_4 = cx.add_view(window_id, |_| ViewB { id: 4 });
 
+        let observed_actions = Rc::new(RefCell::new(Vec::new()));
+        cx.observe_actions({
+            let observed_actions = observed_actions.clone();
+            move |action_id, _| observed_actions.borrow_mut().push(action_id)
+        })
+        .detach();
+
         cx.dispatch_action(
             window_id,
             vec![view_1.id(), view_2.id(), view_3.id(), view_4.id()],
@@ -6331,6 +6404,7 @@ mod tests {
                 "1 b"
             ]
         );
+        assert_eq!(*observed_actions.borrow(), [Action::default().id()]);
 
         // Remove view_1, which doesn't propagate the action
         actions.borrow_mut().clear();
@@ -6353,6 +6427,10 @@ mod tests {
                 "global"
             ]
         );
+        assert_eq!(
+            *observed_actions.borrow(),
+            [Action::default().id(), Action::default().id()]
+        );
     }
 
     #[crate::test(self)]