Allow observing the release of entities

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/gpui/src/app.rs | 148 +++++++++++++++++++++++++++++++++++++++----
1 file changed, 134 insertions(+), 14 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -660,6 +660,7 @@ type GlobalActionCallback = dyn FnMut(&dyn AnyAction, &mut MutableAppContext);
 
 type SubscriptionCallback = Box<dyn FnMut(&dyn Any, &mut MutableAppContext) -> bool>;
 type ObservationCallback = Box<dyn FnMut(&mut MutableAppContext) -> bool>;
+type ReleaseObservationCallback = Box<dyn FnMut(&mut MutableAppContext)>;
 
 pub struct MutableAppContext {
     weak_self: Option<rc::Weak<RefCell<Self>>>,
@@ -674,6 +675,7 @@ pub struct MutableAppContext {
     next_subscription_id: usize,
     subscriptions: Arc<Mutex<HashMap<usize, BTreeMap<usize, SubscriptionCallback>>>>,
     observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ObservationCallback>>>>,
+    release_observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
     debug_elements_callbacks: HashMap<usize, Box<dyn Fn(&AppContext) -> crate::json::Value>>,
@@ -717,6 +719,7 @@ impl MutableAppContext {
             next_subscription_id: 0,
             subscriptions: Default::default(),
             observations: Default::default(),
+            release_observations: Default::default(),
             presenters_and_platform_windows: HashMap::new(),
             debug_elements_callbacks: HashMap::new(),
             foreground,
@@ -1030,6 +1033,27 @@ impl MutableAppContext {
             observations: Some(Arc::downgrade(&self.observations)),
         }
     }
+
+    pub fn observe_release<E, H, F>(&mut self, handle: &H, mut callback: F) -> Subscription
+    where
+        E: Entity,
+        E::Event: 'static,
+        H: Handle<E>,
+        F: 'static + FnMut(&mut Self),
+    {
+        let id = post_inc(&mut self.next_subscription_id);
+        self.release_observations
+            .lock()
+            .entry(handle.id())
+            .or_default()
+            .insert(id, Box::new(move |cx| callback(cx)));
+        Subscription::ReleaseObservation {
+            id,
+            entity_id: handle.id(),
+            observations: Some(Arc::downgrade(&self.release_observations)),
+        }
+    }
+
     pub(crate) fn notify_model(&mut self, model_id: usize) {
         if self.pending_notifications.insert(model_id) {
             self.pending_effects
@@ -1208,6 +1232,7 @@ impl MutableAppContext {
         self.cx.windows.remove(&window_id);
         self.presenters_and_platform_windows.remove(&window_id);
         self.remove_dropped_entities();
+        self.flush_effects();
     }
 
     fn open_platform_window(&mut self, window_id: usize, window_options: WindowOptions) {
@@ -1358,6 +1383,9 @@ impl MutableAppContext {
                 self.observations.lock().remove(&model_id);
                 let mut model = self.cx.models.remove(&model_id).unwrap();
                 model.release(self);
+                self.pending_effects.push_back(Effect::Release {
+                    entity_id: model_id,
+                });
             }
 
             for (window_id, view_id) in dropped_views {
@@ -1381,6 +1409,9 @@ impl MutableAppContext {
                 if let Some(view_id) = change_focus_to {
                     self.focus(window_id, view_id);
                 }
+
+                self.pending_effects
+                    .push_back(Effect::Release { entity_id: view_id });
             }
 
             for key in dropped_element_states {
@@ -1406,6 +1437,7 @@ impl MutableAppContext {
                         Effect::ViewNotification { window_id, view_id } => {
                             self.notify_view_observers(window_id, view_id)
                         }
+                        Effect::Release { entity_id } => self.notify_release_observers(entity_id),
                         Effect::Focus { window_id, view_id } => {
                             self.focus(window_id, view_id);
                         }
@@ -1568,6 +1600,15 @@ impl MutableAppContext {
         }
     }
 
+    fn notify_release_observers(&mut self, entity_id: usize) {
+        let callbacks = self.release_observations.lock().remove(&entity_id);
+        if let Some(callbacks) = callbacks {
+            for (_, mut callback) in callbacks {
+                callback(self);
+            }
+        }
+    }
+
     fn focus(&mut self, window_id: usize, focused_id: usize) {
         if self
             .cx
@@ -1824,6 +1865,9 @@ pub enum Effect {
         window_id: usize,
         view_id: usize,
     },
+    Release {
+        entity_id: usize,
+    },
     Focus {
         window_id: usize,
         view_id: usize,
@@ -1850,6 +1894,10 @@ impl Debug for Effect {
                 .field("window_id", window_id)
                 .field("view_id", view_id)
                 .finish(),
+            Effect::Release { entity_id } => f
+                .debug_struct("Effect::Release")
+                .field("entity_id", entity_id)
+                .finish(),
             Effect::Focus { window_id, view_id } => f
                 .debug_struct("Effect::Focus")
                 .field("window_id", window_id)
@@ -2072,6 +2120,25 @@ impl<'a, T: Entity> ModelContext<'a, T> {
         })
     }
 
+    pub fn observe_release<S, F>(
+        &mut self,
+        handle: &ModelHandle<S>,
+        mut callback: F,
+    ) -> Subscription
+    where
+        S: Entity,
+        F: 'static + FnMut(&mut T, &mut ModelContext<T>),
+    {
+        let observer = self.weak_handle();
+        self.app.observe_release(handle, move |cx| {
+            if let Some(observer) = observer.upgrade(cx) {
+                observer.update(cx, |observer, cx| {
+                    callback(observer, cx);
+                });
+            }
+        })
+    }
+
     pub fn handle(&self) -> ModelHandle<T> {
         ModelHandle::new(self.model_id, &self.app.cx.ref_counts)
     }
@@ -2305,6 +2372,22 @@ impl<'a, T: View> ViewContext<'a, T> {
         })
     }
 
+    pub fn observe_release<E, F, H>(&mut self, handle: &H, mut callback: F) -> Subscription
+    where
+        E: Entity,
+        H: Handle<E>,
+        F: 'static + FnMut(&mut T, &mut ViewContext<T>),
+    {
+        let observer = self.weak_handle();
+        self.app.observe_release(handle, move |cx| {
+            if let Some(observer) = observer.upgrade(cx) {
+                observer.update(cx, |observer, cx| {
+                    callback(observer, cx);
+                });
+            }
+        })
+    }
+
     pub fn emit(&mut self, payload: T::Event) {
         self.app.pending_effects.push_back(Effect::Event {
             entity_id: self.view_id,
@@ -3263,6 +3346,12 @@ pub enum Subscription {
         entity_id: usize,
         observations: Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, ObservationCallback>>>>>,
     },
+    ReleaseObservation {
+        id: usize,
+        entity_id: usize,
+        observations:
+            Option<Weak<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>>,
+    },
 }
 
 impl Subscription {
@@ -3274,6 +3363,9 @@ impl Subscription {
             Subscription::Observation { observations, .. } => {
                 observations.take();
             }
+            Subscription::ReleaseObservation { observations, .. } => {
+                observations.take();
+            }
         }
     }
 }
@@ -3292,6 +3384,17 @@ impl Drop for Subscription {
                     }
                 }
             }
+            Subscription::ReleaseObservation {
+                id,
+                entity_id,
+                observations,
+            } => {
+                if let Some(observations) = observations.as_ref().and_then(Weak::upgrade) {
+                    if let Some(observations) = observations.lock().get_mut(entity_id) {
+                        observations.remove(id);
+                    }
+                }
+            }
             Subscription::Subscription {
                 id,
                 entity_id,
@@ -3401,7 +3504,10 @@ mod tests {
     use super::*;
     use crate::elements::*;
     use smol::future::poll_once;
-    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
+    use std::{
+        cell::Cell,
+        sync::atomic::{AtomicUsize, Ordering::SeqCst},
+    };
 
     #[crate::test(self)]
     fn test_model_handles(cx: &mut MutableAppContext) {
@@ -3652,18 +3758,18 @@ mod tests {
     #[crate::test(self)]
     fn test_entity_release_hooks(cx: &mut MutableAppContext) {
         struct Model {
-            released: Arc<Mutex<bool>>,
+            released: Rc<Cell<bool>>,
         }
 
         struct View {
-            released: Arc<Mutex<bool>>,
+            released: Rc<Cell<bool>>,
         }
 
         impl Entity for Model {
             type Event = ();
 
             fn release(&mut self, _: &mut MutableAppContext) {
-                *self.released.lock() = true;
+                self.released.set(true);
             }
         }
 
@@ -3671,7 +3777,7 @@ mod tests {
             type Event = ();
 
             fn release(&mut self, _: &mut MutableAppContext) {
-                *self.released.lock() = true;
+                self.released.set(true);
             }
         }
 
@@ -3685,27 +3791,41 @@ mod tests {
             }
         }
 
-        let model_released = Arc::new(Mutex::new(false));
-        let view_released = Arc::new(Mutex::new(false));
+        let model_released = Rc::new(Cell::new(false));
+        let model_release_observed = Rc::new(Cell::new(false));
+        let view_released = Rc::new(Cell::new(false));
+        let view_release_observed = Rc::new(Cell::new(false));
 
         let model = cx.add_model(|_| Model {
             released: model_released.clone(),
         });
-
-        let (window_id, _) = cx.add_window(Default::default(), |_| View {
+        let (window_id, view) = cx.add_window(Default::default(), |_| View {
             released: view_released.clone(),
         });
+        assert!(!model_released.get());
+        assert!(!view_released.get());
 
-        assert!(!*model_released.lock());
-        assert!(!*view_released.lock());
+        cx.observe_release(&model, {
+            let model_release_observed = model_release_observed.clone();
+            move |_| model_release_observed.set(true)
+        })
+        .detach();
+        cx.observe_release(&view, {
+            let view_release_observed = view_release_observed.clone();
+            move |_| view_release_observed.set(true)
+        })
+        .detach();
 
         cx.update(move |_| {
             drop(model);
         });
-        assert!(*model_released.lock());
+        assert!(model_released.get());
+        assert!(model_release_observed.get());
 
-        drop(cx.remove_window(window_id));
-        assert!(*view_released.lock());
+        drop(view);
+        cx.remove_window(window_id);
+        assert!(view_released.get());
+        assert!(view_release_observed.get());
     }
 
     #[crate::test(self)]