In handle ::condition, re-poll on events as well as notifications

Max Brunsfeld created

Change summary

gpui/src/app.rs | 147 +++++++++++++++++++++++++++++---------------------
1 file changed, 84 insertions(+), 63 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -12,12 +12,12 @@ use keymap::MatchResult;
 use parking_lot::{Mutex, RwLock};
 use pathfinder_geometry::{rect::RectF, vector::vec2f};
 use platform::Event;
-use postage::{sink::Sink as _, stream::Stream as _};
+use postage::{mpsc, sink::Sink as _, stream::Stream as _};
 use smol::prelude::*;
 use std::{
     any::{type_name, Any, TypeId},
     cell::RefCell,
-    collections::{hash_map::Entry, HashMap, HashSet, VecDeque},
+    collections::{HashMap, HashSet, VecDeque},
     fmt::{self, Debug},
     hash::{Hash, Hasher},
     marker::PhantomData,
@@ -388,7 +388,6 @@ pub struct MutableAppContext {
     subscriptions: HashMap<usize, Vec<Subscription>>,
     model_observations: HashMap<usize, Vec<ModelObservation>>,
     view_observations: HashMap<usize, Vec<ViewObservation>>,
-    async_observations: HashMap<usize, postage::broadcast::Sender<()>>,
     window_invalidations: HashMap<usize, WindowInvalidation>,
     presenters_and_platform_windows:
         HashMap<usize, (Rc<RefCell<Presenter>>, Box<dyn platform::Window>)>,
@@ -430,7 +429,6 @@ impl MutableAppContext {
             subscriptions: HashMap::new(),
             model_observations: HashMap::new(),
             view_observations: HashMap::new(),
-            async_observations: HashMap::new(),
             window_invalidations: HashMap::new(),
             presenters_and_platform_windows: HashMap::new(),
             debug_elements_callbacks: HashMap::new(),
@@ -897,13 +895,11 @@ impl MutableAppContext {
                 self.ctx.models.remove(&model_id);
                 self.subscriptions.remove(&model_id);
                 self.model_observations.remove(&model_id);
-                self.async_observations.remove(&model_id);
             }
 
             for (window_id, view_id) in dropped_views {
                 self.subscriptions.remove(&view_id);
                 self.model_observations.remove(&view_id);
-                self.async_observations.remove(&view_id);
                 if let Some(window) = self.ctx.windows.get_mut(&window_id) {
                     self.window_invalidations
                         .entry(window_id)
@@ -1082,12 +1078,6 @@ impl MutableAppContext {
                 }
             }
         }
-
-        if let Entry::Occupied(mut entry) = self.async_observations.entry(observed_id) {
-            if entry.get_mut().blocking_send(()).is_err() {
-                entry.remove_entry();
-            }
-        }
     }
 
     fn notify_view_observers(&mut self, window_id: usize, view_id: usize) {
@@ -1098,7 +1088,12 @@ impl MutableAppContext {
             .insert(view_id);
 
         if let Some(observations) = self.view_observations.remove(&view_id) {
-            if self.ctx.models.contains_key(&view_id) {
+            if self
+                .ctx
+                .windows
+                .get(&window_id)
+                .map_or(false, |w| w.views.contains_key(&view_id))
+            {
                 for mut observation in observations {
                     let alive = if let Some(mut view) = self
                         .ctx
@@ -1134,12 +1129,6 @@ impl MutableAppContext {
                 }
             }
         }
-
-        if let Entry::Occupied(mut entry) = self.async_observations.entry(view_id) {
-            if entry.get_mut().blocking_send(()).is_err() {
-                entry.remove_entry();
-            }
-        }
     }
 
     fn focus(&mut self, window_id: usize, focused_id: usize) {
@@ -1780,6 +1769,10 @@ impl<'a, T: View> ViewContext<'a, T> {
         self.window_id
     }
 
+    pub fn view_id(&self) -> usize {
+        self.view_id
+    }
+
     pub fn foreground(&self) -> &Rc<executor::Foreground> {
         self.app.foreground_executor()
     }
@@ -1855,22 +1848,11 @@ impl<'a, T: View> ViewContext<'a, T> {
         F: 'static + FnMut(&mut T, ModelHandle<E>, &E::Event, &mut ViewContext<T>),
     {
         let emitter_handle = handle.downgrade();
-        self.app
-            .subscriptions
-            .entry(handle.id())
-            .or_default()
-            .push(Subscription::FromView {
-                window_id: self.window_id,
-                view_id: self.view_id,
-                callback: Box::new(move |view, payload, app, window_id, view_id| {
-                    if let Some(emitter_handle) = emitter_handle.upgrade(app.as_ref()) {
-                        let model = view.downcast_mut().expect("downcast is type safe");
-                        let payload = payload.downcast_ref().expect("downcast is type safe");
-                        let mut ctx = ViewContext::new(app, window_id, view_id);
-                        callback(model, emitter_handle, payload, &mut ctx);
-                    }
-                }),
-            });
+        self.subscribe(handle, move |model, payload, ctx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(ctx.as_ref()) {
+                callback(model, emitter_handle, payload, ctx);
+            }
+        });
     }
 
     pub fn subscribe_to_view<V, F>(&mut self, handle: &ViewHandle<V>, mut callback: F)
@@ -1880,7 +1862,19 @@ impl<'a, T: View> ViewContext<'a, T> {
         F: 'static + FnMut(&mut T, ViewHandle<V>, &V::Event, &mut ViewContext<T>),
     {
         let emitter_handle = handle.downgrade();
+        self.subscribe(handle, move |view, payload, ctx| {
+            if let Some(emitter_handle) = emitter_handle.upgrade(ctx.as_ref()) {
+                callback(view, emitter_handle, payload, ctx);
+            }
+        });
+    }
 
+    pub fn subscribe<E, F>(&mut self, handle: &impl Handle<E>, mut callback: F)
+    where
+        E: Entity,
+        E::Event: 'static,
+        F: 'static + FnMut(&mut T, &E::Event, &mut ViewContext<T>),
+    {
         self.app
             .subscriptions
             .entry(handle.id())
@@ -1888,13 +1882,11 @@ impl<'a, T: View> ViewContext<'a, T> {
             .push(Subscription::FromView {
                 window_id: self.window_id,
                 view_id: self.view_id,
-                callback: Box::new(move |view, payload, app, window_id, view_id| {
-                    if let Some(emitter_handle) = emitter_handle.upgrade(&app) {
-                        let model = view.downcast_mut().expect("downcast is type safe");
-                        let payload = payload.downcast_ref().expect("downcast is type safe");
-                        let mut ctx = ViewContext::new(app, window_id, view_id);
-                        callback(model, emitter_handle, payload, &mut ctx);
-                    }
+                callback: Box::new(move |entity, payload, app, window_id, view_id| {
+                    let entity = entity.downcast_mut().expect("downcast is type safe");
+                    let payload = payload.downcast_ref().expect("downcast is type safe");
+                    let mut ctx = ViewContext::new(app, window_id, view_id);
+                    callback(entity, payload, &mut ctx);
                 }),
             });
     }
@@ -2138,12 +2130,24 @@ impl<T: Entity> ModelHandle<T> {
         ctx: &TestAppContext,
         mut predicate: impl FnMut(&T, &AppContext) -> bool,
     ) -> impl Future<Output = ()> {
+        let (tx, mut rx) = mpsc::channel(1024);
+
         let mut ctx = ctx.0.borrow_mut();
-        let tx = ctx
-            .async_observations
-            .entry(self.id())
-            .or_insert_with(|| postage::broadcast::channel(128).0);
-        let mut rx = tx.subscribe();
+        self.update(&mut *ctx, |_, ctx| {
+            ctx.observe(self, {
+                let mut tx = tx.clone();
+                move |_, _, _| {
+                    tx.blocking_send(()).ok();
+                }
+            });
+            ctx.subscribe(self, {
+                let mut tx = tx.clone();
+                move |_, _, _| {
+                    tx.blocking_send(()).ok();
+                }
+            })
+        });
+
         let ctx = ctx.weak_self.as_ref().unwrap().upgrade().unwrap();
         let handle = self.downgrade();
 
@@ -2310,19 +2314,41 @@ impl<T: View> ViewHandle<T> {
     pub fn condition(
         &self,
         ctx: &TestAppContext,
-        mut predicate: impl 'static + FnMut(&T, &AppContext) -> bool,
-    ) -> impl 'static + Future<Output = ()> {
+        predicate: impl FnMut(&T, &AppContext) -> bool,
+    ) -> impl Future<Output = ()> {
+        self.condition_with_duration(Duration::from_millis(500), ctx, predicate)
+    }
+
+    pub fn condition_with_duration(
+        &self,
+        duration: Duration,
+        ctx: &TestAppContext,
+        mut predicate: impl FnMut(&T, &AppContext) -> bool,
+    ) -> impl Future<Output = ()> {
+        let (tx, mut rx) = mpsc::channel(1024);
+
         let mut ctx = ctx.0.borrow_mut();
-        let tx = ctx
-            .async_observations
-            .entry(self.id())
-            .or_insert_with(|| postage::broadcast::channel(128).0);
-        let mut rx = tx.subscribe();
+        self.update(&mut *ctx, |_, ctx| {
+            ctx.observe_view(self, {
+                let mut tx = tx.clone();
+                move |_, _, _| {
+                    tx.blocking_send(()).ok();
+                }
+            });
+
+            ctx.subscribe(self, {
+                let mut tx = tx.clone();
+                move |_, _, _| {
+                    tx.blocking_send(()).ok();
+                }
+            })
+        });
+
         let ctx = ctx.weak_self.as_ref().unwrap().upgrade().unwrap();
         let handle = self.downgrade();
 
         async move {
-            timeout(Duration::from_millis(200), async move {
+            timeout(duration, async move {
                 loop {
                     {
                         let ctx = ctx.borrow();
@@ -2330,7 +2356,7 @@ impl<T: View> ViewHandle<T> {
                         if predicate(
                             handle
                                 .upgrade(ctx)
-                                .expect("model dropped with pending condition")
+                                .expect("view dropped with pending condition")
                                 .read(ctx),
                             ctx,
                         ) {
@@ -2340,7 +2366,7 @@ impl<T: View> ViewHandle<T> {
 
                     rx.recv()
                         .await
-                        .expect("model dropped with pending condition");
+                        .expect("view dropped with pending condition");
                 }
             })
             .await
@@ -3537,9 +3563,7 @@ mod tests {
             model.update(&mut app, |model, ctx| model.inc(ctx));
             assert_eq!(poll_once(&mut condition2).await, Some(()));
 
-            // Broadcast channel should be removed if no conditions remain on next notification.
             model.update(&mut app, |_, ctx| ctx.notify());
-            app.update(|ctx| assert!(ctx.async_observations.get(&model.id()).is_none()));
         });
     }
 
@@ -3617,10 +3641,7 @@ mod tests {
 
             view.update(&mut app, |view, ctx| view.inc(ctx));
             assert_eq!(poll_once(&mut condition2).await, Some(()));
-
-            // Broadcast channel should be removed if no conditions remain on next notification.
             view.update(&mut app, |_, ctx| ctx.notify());
-            app.update(|ctx| assert!(ctx.async_observations.get(&view.id()).is_none()));
         });
     }
 
@@ -3650,7 +3671,7 @@ mod tests {
     }
 
     #[test]
-    #[should_panic(expected = "model dropped with pending condition")]
+    #[should_panic(expected = "view dropped with pending condition")]
     fn test_view_condition_panic_on_drop() {
         struct View;