gpui: Cancel foreground tasks when the app is dropped (#45768)

Mikayla Maki created

This is in preparation for removing `Result<T>` from the `AsyncApp`
methods

Refactor machine goes brrrrrr

Plan and tracker for this PR:
https://gist.github.com/mikayla-maki/7dfc0d4907e76de119b5712e24665f02

This PR should be safe to merge, as I cannot observe any changes in
behavior from adding this to our application.

Release Notes:

- N/A

Change summary

crates/gpui/src/app.rs                      |   3 
crates/gpui/src/app/async_context.rs        |   8 
crates/gpui/src/app/test_context.rs         |   1 
crates/gpui/src/executor.rs                 | 381 ++++++++++++++++++++++
crates/gpui/src/platform.rs                 |  22 +
crates/gpui/src/platform/mac/dispatcher.rs  |   8 
crates/gpui/src/platform/test/dispatcher.rs |   9 
7 files changed, 420 insertions(+), 12 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -580,6 +580,7 @@ impl GpuiMode {
 /// You need a reference to an `App` to access the state of a [Entity].
 pub struct App {
     pub(crate) this: Weak<AppCell>,
+    pub(crate) liveness: std::sync::Arc<()>,
     pub(crate) platform: Rc<dyn Platform>,
     pub(crate) mode: GpuiMode,
     text_system: Arc<TextSystem>,
@@ -658,6 +659,7 @@ impl App {
         let app = Rc::new_cyclic(|this| AppCell {
             app: RefCell::new(App {
                 this: this.clone(),
+                liveness: std::sync::Arc::new(()),
                 platform: platform.clone(),
                 text_system,
                 mode: GpuiMode::Production,
@@ -1476,6 +1478,7 @@ impl App {
     pub fn to_async(&self) -> AsyncApp {
         AsyncApp {
             app: self.this.clone(),
+            liveness_token: std::sync::Arc::downgrade(&self.liveness),
             background_executor: self.background_executor.clone(),
             foreground_executor: self.foreground_executor.clone(),
         }

crates/gpui/src/app/async_context.rs 🔗

@@ -16,6 +16,7 @@ use super::{Context, WeakEntity};
 #[derive(Clone)]
 pub struct AsyncApp {
     pub(crate) app: Weak<AppCell>,
+    pub(crate) liveness_token: std::sync::Weak<()>,
     pub(crate) background_executor: BackgroundExecutor,
     pub(crate) foreground_executor: ForegroundExecutor,
 }
@@ -185,7 +186,7 @@ impl AsyncApp {
     {
         let mut cx = self.clone();
         self.foreground_executor
-            .spawn(async move { f(&mut cx).await })
+            .spawn_context(self.liveness_token.clone(), async move { f(&mut cx).await })
     }
 
     /// Determine whether global state of the specified type has been assigned.
@@ -334,7 +335,10 @@ impl AsyncWindowContext {
     {
         let mut cx = self.clone();
         self.foreground_executor
-            .spawn(async move { f(&mut cx).await })
+            .spawn_context(
+                self.app.liveness_token.clone(),
+                async move { f(&mut cx).await },
+            )
     }
 
     /// Present a platform dialog.

crates/gpui/src/app/test_context.rs 🔗

@@ -405,6 +405,7 @@ impl TestAppContext {
     pub fn to_async(&self) -> AsyncApp {
         AsyncApp {
             app: Rc::downgrade(&self.app),
+            liveness_token: std::sync::Arc::downgrade(&self.app.borrow().liveness),
             background_executor: self.background_executor.clone(),
             foreground_executor: self.foreground_executor.clone(),
         }

crates/gpui/src/executor.rs 🔗

@@ -125,6 +125,30 @@ impl<T> Task<T> {
             Task(TaskState::Spawned(task)) => task.detach(),
         }
     }
+
+    /// Converts this task into a fallible task that returns `Option<T>`.
+    ///
+    /// Unlike the standard `Task<T>`, a [`FallibleTask`] will return `None`
+    /// if the app was dropped while the task is executing.
+    ///
+    /// # Example
+    ///
+    /// ```ignore
+    /// // Background task that gracefully handles app shutdown:
+    /// cx.background_spawn(async move {
+    ///     let result = foreground_task.fallible().await;
+    ///     if let Some(value) = result {
+    ///         // Process the value
+    ///     }
+    ///     // If None, app was shut down - just exit gracefully
+    /// }).detach();
+    /// ```
+    pub fn fallible(self) -> FallibleTask<T> {
+        FallibleTask(match self.0 {
+            TaskState::Ready(val) => FallibleTaskState::Ready(val),
+            TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
+        })
+    }
 }
 
 impl<E, T> Task<Result<T, E>>
@@ -154,6 +178,55 @@ impl<T> Future for Task<T> {
     }
 }
 
+/// A task that returns `Option<T>` instead of panicking when cancelled.
+#[must_use]
+pub struct FallibleTask<T>(FallibleTaskState<T>);
+
+enum FallibleTaskState<T> {
+    /// A task that is ready to return a value
+    Ready(Option<T>),
+
+    /// A task that is currently running (wraps async_task::FallibleTask).
+    Spawned(async_task::FallibleTask<T, RunnableMeta>),
+}
+
+impl<T> FallibleTask<T> {
+    /// Creates a new fallible task that will resolve with the value.
+    pub fn ready(val: T) -> Self {
+        FallibleTask(FallibleTaskState::Ready(Some(val)))
+    }
+
+    /// Detaching a task runs it to completion in the background.
+    pub fn detach(self) {
+        match self.0 {
+            FallibleTaskState::Ready(_) => {}
+            FallibleTaskState::Spawned(task) => task.detach(),
+        }
+    }
+}
+
+impl<T> Future for FallibleTask<T> {
+    type Output = Option<T>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+        match unsafe { self.get_unchecked_mut() } {
+            FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
+            FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
+        }
+    }
+}
+
+impl<T> std::fmt::Debug for FallibleTask<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self.0 {
+            FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
+            FallibleTaskState::Spawned(task) => {
+                f.debug_tuple("FallibleTask::Spawned").field(task).finish()
+            }
+        }
+    }
+}
+
 /// A task label is an opaque identifier that you can use to
 /// refer to a task in tests.
 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
@@ -252,7 +325,10 @@ impl BackgroundExecutor {
 
         let (runnable, task) = unsafe {
             async_task::Builder::new()
-                .metadata(RunnableMeta { location })
+                .metadata(RunnableMeta {
+                    location,
+                    app: None,
+                })
                 .spawn_unchecked(
                     move |_| async {
                         let _notify_guard = NotifyOnDrop(pair);
@@ -330,7 +406,10 @@ impl BackgroundExecutor {
             );
 
             async_task::Builder::new()
-                .metadata(RunnableMeta { location })
+                .metadata(RunnableMeta {
+                    location,
+                    app: None,
+                })
                 .spawn(
                     move |_| future,
                     move |runnable| {
@@ -340,7 +419,10 @@ impl BackgroundExecutor {
         } else {
             let location = core::panic::Location::caller();
             async_task::Builder::new()
-                .metadata(RunnableMeta { location })
+                .metadata(RunnableMeta {
+                    location,
+                    app: None,
+                })
                 .spawn(
                     move |_| future,
                     move |runnable| {
@@ -566,7 +648,10 @@ impl BackgroundExecutor {
         }
         let location = core::panic::Location::caller();
         let (runnable, task) = async_task::Builder::new()
-            .metadata(RunnableMeta { location })
+            .metadata(RunnableMeta {
+                location,
+                app: None,
+            })
             .spawn(move |_| async move {}, {
                 let dispatcher = self.dispatcher.clone();
                 move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable))
@@ -681,7 +766,7 @@ impl ForegroundExecutor {
     where
         R: 'static,
     {
-        self.spawn_with_priority(Priority::default(), future)
+        self.inner_spawn(None, Priority::default(), future)
     }
 
     /// Enqueues the given Task to run on the main thread at some point in the future.
@@ -691,6 +776,31 @@ impl ForegroundExecutor {
         priority: Priority,
         future: impl Future<Output = R> + 'static,
     ) -> Task<R>
+    where
+        R: 'static,
+    {
+        self.inner_spawn(None, priority, future)
+    }
+
+    #[track_caller]
+    pub(crate) fn spawn_context<R>(
+        &self,
+        app: std::sync::Weak<()>,
+        future: impl Future<Output = R> + 'static,
+    ) -> Task<R>
+    where
+        R: 'static,
+    {
+        self.inner_spawn(Some(app), Priority::default(), future)
+    }
+
+    #[track_caller]
+    pub(crate) fn inner_spawn<R>(
+        &self,
+        app: Option<std::sync::Weak<()>>,
+        priority: Priority,
+        future: impl Future<Output = R> + 'static,
+    ) -> Task<R>
     where
         R: 'static,
     {
@@ -702,6 +812,7 @@ impl ForegroundExecutor {
             dispatcher: Arc<dyn PlatformDispatcher>,
             future: AnyLocalFuture<R>,
             location: &'static core::panic::Location<'static>,
+            app: Option<std::sync::Weak<()>>,
             priority: Priority,
         ) -> Task<R> {
             let (runnable, task) = spawn_local_with_source_location(
@@ -709,12 +820,12 @@ impl ForegroundExecutor {
                 move |runnable| {
                     dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
                 },
-                RunnableMeta { location },
+                RunnableMeta { location, app },
             );
             runnable.schedule();
             Task(TaskState::Spawned(task))
         }
-        inner::<R>(dispatcher, Box::pin(future), location, priority)
+        inner::<R>(dispatcher, Box::pin(future), location, app, priority)
     }
 }
 
@@ -847,3 +958,259 @@ impl Drop for Scope<'_> {
         self.executor.block(self.rx.next());
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::{App, TestDispatcher, TestPlatform};
+    use rand::SeedableRng;
+    use std::cell::RefCell;
+
+    #[test]
+    fn sanity_test_tasks_run() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor, foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness);
+
+        let task_ran = Rc::new(RefCell::new(false));
+
+        foreground_executor
+            .spawn_context(liveness_token, {
+                let task_ran = Rc::clone(&task_ran);
+                async move {
+                    *task_ran.borrow_mut() = true;
+                }
+            })
+            .detach();
+
+        // Run dispatcher while app is still alive
+        dispatcher.run_until_parked();
+
+        // Task should have run
+        assert!(
+            *task_ran.borrow(),
+            "Task should run normally when app is alive"
+        );
+    }
+
+    #[test]
+    fn test_task_cancelled_when_app_dropped() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor, foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness);
+        let app_weak = Rc::downgrade(&app);
+
+        let task_ran = Rc::new(RefCell::new(false));
+        let task_ran_clone = Rc::clone(&task_ran);
+
+        foreground_executor
+            .spawn_context(liveness_token, async move {
+                *task_ran_clone.borrow_mut() = true;
+            })
+            .detach();
+
+        drop(app);
+
+        assert!(app_weak.upgrade().is_none(), "App should have been dropped");
+
+        dispatcher.run_until_parked();
+
+        // The task should have been cancelled, not run
+        assert!(
+            !*task_ran.borrow(),
+            "Task should have been cancelled when app was dropped, but it ran!"
+        );
+    }
+
+    #[test]
+    fn test_nested_tasks_both_cancel() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor, foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness);
+        let app_weak = Rc::downgrade(&app);
+
+        let outer_completed = Rc::new(RefCell::new(false));
+        let inner_completed = Rc::new(RefCell::new(false));
+        let reached_await = Rc::new(RefCell::new(false));
+
+        let outer_flag = Rc::clone(&outer_completed);
+        let inner_flag = Rc::clone(&inner_completed);
+        let await_flag = Rc::clone(&reached_await);
+
+        // Channel to block the inner task until we're ready
+        let (tx, rx) = futures::channel::oneshot::channel::<()>();
+
+        // We need clones of executor and liveness_token for the inner spawn
+        let inner_executor = foreground_executor.clone();
+        let inner_liveness_token = liveness_token.clone();
+
+        foreground_executor
+            .spawn_context(liveness_token, async move {
+                let inner_task = inner_executor.spawn_context(inner_liveness_token, {
+                    let inner_flag = Rc::clone(&inner_flag);
+                    async move {
+                        rx.await.ok();
+                        *inner_flag.borrow_mut() = true;
+                    }
+                });
+
+                *await_flag.borrow_mut() = true;
+
+                inner_task.await;
+
+                *outer_flag.borrow_mut() = true;
+            })
+            .detach();
+
+        // Run dispatcher until outer task reaches the await point
+        // The inner task will be blocked on the channel
+        dispatcher.run_until_parked();
+
+        // Verify we actually reached the await point before dropping the app
+        assert!(
+            *reached_await.borrow(),
+            "Outer task should have reached the await point"
+        );
+
+        // Neither task should have completed yet
+        assert!(
+            !*outer_completed.borrow(),
+            "Outer task should not have completed yet"
+        );
+        assert!(
+            !*inner_completed.borrow(),
+            "Inner task should not have completed yet"
+        );
+
+        // Drop the channel sender and app while outer is awaiting inner
+        drop(tx);
+        drop(app);
+        assert!(app_weak.upgrade().is_none(), "App should have been dropped");
+
+        // Run dispatcher - both tasks should be cancelled
+        dispatcher.run_until_parked();
+
+        // Neither task should have completed (both were cancelled)
+        assert!(
+            !*outer_completed.borrow(),
+            "Outer task should have been cancelled, not completed"
+        );
+        assert!(
+            !*inner_completed.borrow(),
+            "Inner task should have been cancelled, not completed"
+        );
+    }
+
+    #[test]
+    fn test_task_without_app_tracking_still_runs() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor, foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let app_weak = Rc::downgrade(&app);
+
+        let task_ran = Rc::new(RefCell::new(false));
+        let task_ran_clone = Rc::clone(&task_ran);
+
+        let _task = foreground_executor.spawn(async move {
+            *task_ran_clone.borrow_mut() = true;
+        });
+
+        drop(app);
+
+        assert!(app_weak.upgrade().is_none(), "App should have been dropped");
+
+        dispatcher.run_until_parked();
+
+        assert!(
+            *task_ran.borrow(),
+            "Task without app tracking should still run after app is dropped"
+        );
+    }
+
+    #[test]
+    #[should_panic]
+    fn test_polling_cancelled_task_panics() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor.clone(), foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness);
+        let app_weak = Rc::downgrade(&app);
+
+        let task = foreground_executor.spawn_context(liveness_token, async move { 42 });
+
+        drop(app);
+
+        assert!(app_weak.upgrade().is_none(), "App should have been dropped");
+
+        dispatcher.run_until_parked();
+
+        background_executor.block(task);
+    }
+
+    #[test]
+    fn test_polling_cancelled_task_returns_none_with_fallible() {
+        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
+        let arc_dispatcher = Arc::new(dispatcher.clone());
+        let background_executor = BackgroundExecutor::new(arc_dispatcher.clone());
+        let foreground_executor = ForegroundExecutor::new(arc_dispatcher);
+
+        let platform = TestPlatform::new(background_executor.clone(), foreground_executor.clone());
+        let asset_source = Arc::new(());
+        let http_client = http_client::FakeHttpClient::with_404_response();
+
+        let app = App::new_app(platform, asset_source, http_client);
+        let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness);
+        let app_weak = Rc::downgrade(&app);
+
+        let task = foreground_executor
+            .spawn_context(liveness_token, async move { 42 })
+            .fallible();
+
+        drop(app);
+
+        assert!(app_weak.upgrade().is_none(), "App should have been dropped");
+
+        dispatcher.run_until_parked();
+
+        let result = background_executor.block(task);
+        assert_eq!(result, None, "Cancelled task should return None");
+    }
+}

crates/gpui/src/platform.rs 🔗

@@ -575,10 +575,30 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle {
 /// This type is public so that our test macro can generate and use it, but it should not
 /// be considered part of our public API.
 #[doc(hidden)]
-#[derive(Debug)]
 pub struct RunnableMeta {
     /// Location of the runnable
     pub location: &'static core::panic::Location<'static>,
+    /// Weak reference to check if the app is still alive before running this task
+    pub app: Option<std::sync::Weak<()>>,
+}
+
+impl std::fmt::Debug for RunnableMeta {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("RunnableMeta")
+            .field("location", &self.location)
+            .field("app_alive", &self.is_app_alive())
+            .finish()
+    }
+}
+
+impl RunnableMeta {
+    /// Returns true if the app is still alive (or if no app tracking is configured).
+    pub fn is_app_alive(&self) -> bool {
+        match &self.app {
+            Some(weak) => weak.strong_count() > 0,
+            None => true,
+        }
+    }
 }
 
 #[doc(hidden)]

crates/gpui/src/platform/mac/dispatcher.rs 🔗

@@ -251,7 +251,13 @@ extern "C" fn trampoline(runnable: *mut c_void) {
     let task =
         unsafe { Runnable::<RunnableMeta>::from_raw(NonNull::new_unchecked(runnable as *mut ())) };
 
-    let location = task.metadata().location;
+    let metadata = task.metadata();
+    let location = metadata.location;
+
+    if !metadata.is_app_alive() {
+        drop(task);
+        return;
+    }
 
     let start = Instant::now();
     let timing = TaskTiming {

crates/gpui/src/platform/test/dispatcher.rs 🔗

@@ -177,7 +177,14 @@ impl TestDispatcher {
 
         // todo(localcc): add timings to tests
         match runnable {
-            RunnableVariant::Meta(runnable) => runnable.run(),
+            RunnableVariant::Meta(runnable) => {
+                if !runnable.metadata().is_app_alive() {
+                    drop(runnable);
+                    self.state.lock().is_main_thread = was_main_thread;
+                    return true;
+                }
+                runnable.run()
+            }
             RunnableVariant::Compat(runnable) => runnable.run(),
         };