Enforce Send bound on next frame callback functions, because they're invoked on arbitrary threads (#3215)

Nathan Sobo created

This enforces the send bound, then uses an async task looping on a
channel to run callbacks on the main thread.

Change summary

crates/gpui2/src/app.rs                         | 55 +++++++++++++---
crates/gpui2/src/app/async_context.rs           | 10 +-
crates/gpui2/src/app/test_context.rs            |  6 
crates/gpui2/src/elements/img.rs                |  4 
crates/gpui2/src/platform.rs                    |  2 
crates/gpui2/src/platform/mac/display_linker.rs |  4 
crates/gpui2/src/platform/mac/platform.rs       |  2 
crates/gpui2/src/platform/test/platform.rs      |  2 
crates/gpui2/src/window.rs                      | 60 ++++++++++++------
crates/ui2/src/elements/avatar.rs               | 23 ++++++-
crates/ui2/src/elements/player.rs               |  4 
11 files changed, 119 insertions(+), 53 deletions(-)

Detailed changes

crates/gpui2/src/app.rs 🔗

@@ -5,6 +5,7 @@ mod model_context;
 mod test_context;
 
 pub use async_context::*;
+use derive_more::{Deref, DerefMut};
 pub use entity_map::*;
 pub use model_context::*;
 use refineable::Refineable;
@@ -27,7 +28,7 @@ use parking_lot::Mutex;
 use slotmap::SlotMap;
 use std::{
     any::{type_name, Any, TypeId},
-    cell::RefCell,
+    cell::{Ref, RefCell, RefMut},
     marker::PhantomData,
     mem,
     ops::{Deref, DerefMut},
@@ -38,7 +39,31 @@ use std::{
 };
 use util::http::{self, HttpClient};
 
-pub struct App(Rc<RefCell<AppContext>>);
+/// Temporary(?) wrapper around RefCell<AppContext> to help us debug any double borrows.
+/// Strongly consider removing after stabilization.
+pub struct AppCell {
+    app: RefCell<AppContext>,
+}
+
+impl AppCell {
+    pub fn borrow(&self) -> AppRef {
+        AppRef(self.app.borrow())
+    }
+
+    pub fn borrow_mut(&self) -> AppRefMut {
+        // let thread_id = std::thread::current().id();
+        // dbg!("borrowed {thread_id:?}");
+        AppRefMut(self.app.borrow_mut())
+    }
+}
+
+#[derive(Deref, DerefMut)]
+pub struct AppRef<'a>(Ref<'a, AppContext>);
+
+#[derive(Deref, DerefMut)]
+pub struct AppRefMut<'a>(RefMut<'a, AppContext>);
+
+pub struct App(Rc<AppCell>);
 
 /// Represents an application before it is fully launched. Once your app is
 /// configured, you'll start the app with `App::run`.
@@ -112,14 +137,20 @@ impl App {
 }
 
 type ActionBuilder = fn(json: Option<serde_json::Value>) -> anyhow::Result<Box<dyn Action>>;
-type FrameCallback = Box<dyn FnOnce(&mut WindowContext)>;
+pub(crate) type FrameCallback = Box<dyn FnOnce(&mut AppContext)>;
 type Handler = Box<dyn FnMut(&mut AppContext) -> bool + 'static>;
 type Listener = Box<dyn FnMut(&dyn Any, &mut AppContext) -> bool + 'static>;
 type QuitHandler = Box<dyn FnOnce(&mut AppContext) -> LocalBoxFuture<'static, ()> + 'static>;
 type ReleaseListener = Box<dyn FnOnce(&mut dyn Any, &mut AppContext) + 'static>;
 
+// struct FrameConsumer {
+//     next_frame_callbacks: Vec<FrameCallback>,
+//     task: Task<()>,
+//     display_linker
+// }
+
 pub struct AppContext {
-    this: Weak<RefCell<AppContext>>,
+    this: Weak<AppCell>,
     pub(crate) platform: Rc<dyn Platform>,
     app_metadata: AppMetadata,
     text_system: Arc<TextSystem>,
@@ -127,6 +158,7 @@ pub struct AppContext {
     pending_updates: usize,
     pub(crate) active_drag: Option<AnyDrag>,
     pub(crate) next_frame_callbacks: HashMap<DisplayId, Vec<FrameCallback>>,
+    pub(crate) frame_consumers: HashMap<DisplayId, Task<()>>,
     pub(crate) background_executor: BackgroundExecutor,
     pub(crate) foreground_executor: ForegroundExecutor,
     pub(crate) svg_renderer: SvgRenderer,
@@ -157,7 +189,7 @@ impl AppContext {
         platform: Rc<dyn Platform>,
         asset_source: Arc<dyn AssetSource>,
         http_client: Arc<dyn HttpClient>,
-    ) -> Rc<RefCell<Self>> {
+    ) -> Rc<AppCell> {
         let executor = platform.background_executor();
         let foreground_executor = platform.foreground_executor();
         assert!(
@@ -174,15 +206,17 @@ impl AppContext {
             app_version: platform.app_version().ok(),
         };
 
-        Rc::new_cyclic(|this| {
-            RefCell::new(AppContext {
+        Rc::new_cyclic(|this| AppCell {
+            app: RefCell::new(AppContext {
                 this: this.clone(),
-                text_system,
                 platform,
                 app_metadata,
+                text_system,
                 flushing_effects: false,
                 pending_updates: 0,
-                next_frame_callbacks: Default::default(),
+                active_drag: None,
+                next_frame_callbacks: HashMap::default(),
+                frame_consumers: HashMap::default(),
                 background_executor: executor,
                 foreground_executor,
                 svg_renderer: SvgRenderer::new(asset_source.clone()),
@@ -205,8 +239,7 @@ impl AppContext {
                 quit_observers: SubscriberSet::new(),
                 layout_id_buffer: Default::default(),
                 propagate_event: true,
-                active_drag: None,
-            })
+            }),
         })
     }
 

crates/gpui2/src/app/async_context.rs 🔗

@@ -1,15 +1,15 @@
 use crate::{
-    AnyView, AnyWindowHandle, AppContext, BackgroundExecutor, Context, ForegroundExecutor, Model,
-    ModelContext, Render, Result, Task, View, ViewContext, VisualContext, WindowContext,
+    AnyView, AnyWindowHandle, AppCell, AppContext, BackgroundExecutor, Context, ForegroundExecutor,
+    Model, ModelContext, Render, Result, Task, View, ViewContext, VisualContext, WindowContext,
     WindowHandle,
 };
 use anyhow::{anyhow, Context as _};
 use derive_more::{Deref, DerefMut};
-use std::{cell::RefCell, future::Future, rc::Weak};
+use std::{future::Future, rc::Weak};
 
 #[derive(Clone)]
 pub struct AsyncAppContext {
-    pub(crate) app: Weak<RefCell<AppContext>>,
+    pub(crate) app: Weak<AppCell>,
     pub(crate) background_executor: BackgroundExecutor,
     pub(crate) foreground_executor: ForegroundExecutor,
 }
@@ -121,7 +121,7 @@ impl AsyncAppContext {
             .app
             .upgrade()
             .ok_or_else(|| anyhow!("app was released"))?;
-        let app = app.borrow_mut(); // Need this to compile
+        let app = app.borrow_mut();
         Ok(read(app.global(), &app))
     }
 

crates/gpui2/src/app/test_context.rs 🔗

@@ -1,15 +1,15 @@
 use crate::{
-    AnyView, AnyWindowHandle, AppContext, AsyncAppContext, BackgroundExecutor, Context,
+    AnyView, AnyWindowHandle, AppCell, AppContext, AsyncAppContext, BackgroundExecutor, Context,
     EventEmitter, ForegroundExecutor, Model, ModelContext, Result, Task, TestDispatcher,
     TestPlatform, WindowContext,
 };
 use anyhow::{anyhow, bail};
 use futures::{Stream, StreamExt};
-use std::{cell::RefCell, future::Future, rc::Rc, sync::Arc, time::Duration};
+use std::{future::Future, rc::Rc, sync::Arc, time::Duration};
 
 #[derive(Clone)]
 pub struct TestAppContext {
-    pub app: Rc<RefCell<AppContext>>,
+    pub app: Rc<AppCell>,
     pub background_executor: BackgroundExecutor,
     pub foreground_executor: ForegroundExecutor,
 }

crates/gpui2/src/elements/img.rs 🔗

@@ -109,7 +109,9 @@ where
         let corner_radii = style.corner_radii;
 
         if let Some(uri) = self.uri.clone() {
-            let image_future = cx.image_cache.get(uri);
+            // eprintln!(">>> image_cache.get({uri}");
+            let image_future = cx.image_cache.get(uri.clone());
+            // eprintln!("<<< image_cache.get({uri}");
             if let Some(data) = image_future
                 .clone()
                 .now_or_never()

crates/gpui2/src/platform.rs 🔗

@@ -69,7 +69,7 @@ pub(crate) trait Platform: 'static {
     fn set_display_link_output_callback(
         &self,
         display_id: DisplayId,
-        callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp)>,
+        callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp) + Send>,
     );
     fn start_display_link(&self, display_id: DisplayId);
     fn stop_display_link(&self, display_id: DisplayId);

crates/gpui2/src/platform/mac/display_linker.rs 🔗

@@ -26,13 +26,13 @@ impl MacDisplayLinker {
     }
 }
 
-type OutputCallback = Mutex<Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp)>>;
+type OutputCallback = Mutex<Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp) + Send>>;
 
 impl MacDisplayLinker {
     pub fn set_output_callback(
         &mut self,
         display_id: DisplayId,
-        output_callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp)>,
+        output_callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp) + Send>,
     ) {
         if let Some(mut system_link) = unsafe { sys::DisplayLink::on_display(display_id.0) } {
             let callback = Arc::new(Mutex::new(output_callback));

crates/gpui2/src/platform/mac/platform.rs 🔗

@@ -494,7 +494,7 @@ impl Platform for MacPlatform {
     fn set_display_link_output_callback(
         &self,
         display_id: DisplayId,
-        callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp)>,
+        callback: Box<dyn FnMut(&VideoTimestamp, &VideoTimestamp) + Send>,
     ) {
         self.0
             .lock()

crates/gpui2/src/platform/test/platform.rs 🔗

@@ -81,7 +81,7 @@ impl Platform for TestPlatform {
     fn set_display_link_output_callback(
         &self,
         _display_id: DisplayId,
-        _callback: Box<dyn FnMut(&crate::VideoTimestamp, &crate::VideoTimestamp)>,
+        _callback: Box<dyn FnMut(&crate::VideoTimestamp, &crate::VideoTimestamp) + Send>,
     ) {
         unimplemented!()
     }

crates/gpui2/src/window.rs 🔗

@@ -13,7 +13,10 @@ use crate::{
 use anyhow::{anyhow, Result};
 use collections::HashMap;
 use derive_more::{Deref, DerefMut};
-use futures::channel::oneshot;
+use futures::{
+    channel::{mpsc, oneshot},
+    StreamExt,
+};
 use parking_lot::RwLock;
 use slotmap::SlotMap;
 use smallvec::SmallVec;
@@ -435,42 +438,55 @@ impl<'a> WindowContext<'a> {
     }
 
     /// Schedule the given closure to be run directly after the current frame is rendered.
-    pub fn on_next_frame(&mut self, f: impl FnOnce(&mut WindowContext) + 'static) {
-        let f = Box::new(f);
+    pub fn on_next_frame(&mut self, callback: impl FnOnce(&mut WindowContext) + 'static) {
+        let handle = self.window.handle;
         let display_id = self.window.display_id;
 
-        if let Some(callbacks) = self.next_frame_callbacks.get_mut(&display_id) {
-            callbacks.push(f);
-            // If there was already a callback, it means that we already scheduled a frame.
-            if callbacks.len() > 1 {
-                return;
-            }
-        } else {
-            let mut async_cx = self.to_async();
-            self.next_frame_callbacks.insert(display_id, vec![f]);
+        if !self.frame_consumers.contains_key(&display_id) {
+            let (tx, mut rx) = mpsc::unbounded::<()>();
             self.platform.set_display_link_output_callback(
                 display_id,
-                Box::new(move |_current_time, _output_time| {
-                    let _ = async_cx.update(|_, cx| {
-                        let callbacks = cx
+                Box::new(move |_current_time, _output_time| _ = tx.unbounded_send(())),
+            );
+
+            let consumer_task = self.app.spawn(|cx| async move {
+                while rx.next().await.is_some() {
+                    cx.update(|cx| {
+                        for callback in cx
                             .next_frame_callbacks
                             .get_mut(&display_id)
                             .unwrap()
                             .drain(..)
-                            .collect::<Vec<_>>();
-                        for callback in callbacks {
+                            .collect::<SmallVec<[_; 32]>>()
+                        {
                             callback(cx);
                         }
+                    })
+                    .ok();
 
-                        if cx.next_frame_callbacks.get(&display_id).unwrap().is_empty() {
+                    // Flush effects, then stop the display link if no new next_frame_callbacks have been added.
+
+                    cx.update(|cx| {
+                        if cx.next_frame_callbacks.is_empty() {
                             cx.platform.stop_display_link(display_id);
                         }
-                    });
-                }),
-            );
+                    })
+                    .ok();
+                }
+            });
+            self.frame_consumers.insert(display_id, consumer_task);
+        }
+
+        if self.next_frame_callbacks.is_empty() {
+            self.platform.start_display_link(display_id);
         }
 
-        self.platform.start_display_link(display_id);
+        self.next_frame_callbacks
+            .entry(display_id)
+            .or_default()
+            .push(Box::new(move |cx: &mut AppContext| {
+                cx.update_window(handle, |_root_view, cx| callback(cx)).ok();
+            }));
     }
 
     /// Spawn the future returned by the given closure on the application thread pool.

crates/ui2/src/elements/avatar.rs 🔗

@@ -58,11 +58,26 @@ mod stories {
                 .child(Avatar::new(
                     "https://avatars.githubusercontent.com/u/1714999?v=4",
                 ))
+                .child(Avatar::new(
+                    "https://avatars.githubusercontent.com/u/326587?v=4",
+                ))
+                // .child(Avatar::new(
+                //     "https://avatars.githubusercontent.com/u/326587?v=4",
+                // ))
+                // .child(Avatar::new(
+                //     "https://avatars.githubusercontent.com/u/482957?v=4",
+                // ))
+                // .child(Avatar::new(
+                //     "https://avatars.githubusercontent.com/u/1714999?v=4",
+                // ))
+                // .child(Avatar::new(
+                //     "https://avatars.githubusercontent.com/u/1486634?v=4",
+                // ))
                 .child(Story::label(cx, "Rounded rectangle"))
-                .child(
-                    Avatar::new("https://avatars.githubusercontent.com/u/1714999?v=4")
-                        .shape(Shape::RoundedRectangle),
-                )
+            // .child(
+            //     Avatar::new("https://avatars.githubusercontent.com/u/1714999?v=4")
+            //         .shape(Shape::RoundedRectangle),
+            // )
         }
     }
 }

crates/ui2/src/elements/player.rs 🔗

@@ -139,11 +139,11 @@ impl Player {
     }
 
     pub fn cursor_color<V: 'static>(&self, cx: &mut ViewContext<V>) -> Hsla {
-        cx.theme().styles.player.0[self.index].cursor
+        cx.theme().styles.player.0[self.index % cx.theme().styles.player.0.len()].cursor
     }
 
     pub fn selection_color<V: 'static>(&self, cx: &mut ViewContext<V>) -> Hsla {
-        cx.theme().styles.player.0[self.index].selection
+        cx.theme().styles.player.0[self.index % cx.theme().styles.player.0.len()].selection
     }
 
     pub fn avatar_src(&self) -> &str {