diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 96f815ac0b592600f22b3c9b9686571487ff77a2..c99ee073b85d86c5971f3274a4d149895c821b4e 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -580,6 +580,7 @@ impl GpuiMode { /// You need a reference to an `App` to access the state of a [Entity]. pub struct App { pub(crate) this: Weak, + pub(crate) liveness: std::sync::Arc<()>, pub(crate) platform: Rc, pub(crate) mode: GpuiMode, text_system: Arc, @@ -658,6 +659,7 @@ impl App { let app = Rc::new_cyclic(|this| AppCell { app: RefCell::new(App { this: this.clone(), + liveness: std::sync::Arc::new(()), platform: platform.clone(), text_system, mode: GpuiMode::Production, @@ -1476,6 +1478,7 @@ impl App { pub fn to_async(&self) -> AsyncApp { AsyncApp { app: self.this.clone(), + liveness_token: std::sync::Arc::downgrade(&self.liveness), background_executor: self.background_executor.clone(), foreground_executor: self.foreground_executor.clone(), } diff --git a/crates/gpui/src/app/async_context.rs b/crates/gpui/src/app/async_context.rs index 805dfced162cd27f0cc785a8282ae3b802c2873a..3f9d21ef65ab16350c8d16fa3197ac507e5d503e 100644 --- a/crates/gpui/src/app/async_context.rs +++ b/crates/gpui/src/app/async_context.rs @@ -16,6 +16,7 @@ use super::{Context, WeakEntity}; #[derive(Clone)] pub struct AsyncApp { pub(crate) app: Weak, + pub(crate) liveness_token: std::sync::Weak<()>, pub(crate) background_executor: BackgroundExecutor, pub(crate) foreground_executor: ForegroundExecutor, } @@ -185,7 +186,7 @@ impl AsyncApp { { let mut cx = self.clone(); self.foreground_executor - .spawn(async move { f(&mut cx).await }) + .spawn_context(self.liveness_token.clone(), async move { f(&mut cx).await }) } /// Determine whether global state of the specified type has been assigned. @@ -334,7 +335,10 @@ impl AsyncWindowContext { { let mut cx = self.clone(); self.foreground_executor - .spawn(async move { f(&mut cx).await }) + .spawn_context( + self.app.liveness_token.clone(), + async move { f(&mut cx).await }, + ) } /// Present a platform dialog. diff --git a/crates/gpui/src/app/test_context.rs b/crates/gpui/src/app/test_context.rs index 9b982f9a1ca3c14b99dfc93e938aafe4e2f75cff..d7a1e4704584595d769989b421ccd767fba58985 100644 --- a/crates/gpui/src/app/test_context.rs +++ b/crates/gpui/src/app/test_context.rs @@ -405,6 +405,7 @@ impl TestAppContext { pub fn to_async(&self) -> AsyncApp { AsyncApp { app: Rc::downgrade(&self.app), + liveness_token: std::sync::Arc::downgrade(&self.app.borrow().liveness), background_executor: self.background_executor.clone(), foreground_executor: self.foreground_executor.clone(), } diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 6c2ecb341ff2fe446efd7823c107fd32a557feb5..eb16cbd9a0bce1cc2444167cc793e1c8d55b7053 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -125,6 +125,30 @@ impl Task { Task(TaskState::Spawned(task)) => task.detach(), } } + + /// Converts this task into a fallible task that returns `Option`. + /// + /// Unlike the standard `Task`, a [`FallibleTask`] will return `None` + /// if the app was dropped while the task is executing. + /// + /// # Example + /// + /// ```ignore + /// // Background task that gracefully handles app shutdown: + /// cx.background_spawn(async move { + /// let result = foreground_task.fallible().await; + /// if let Some(value) = result { + /// // Process the value + /// } + /// // If None, app was shut down - just exit gracefully + /// }).detach(); + /// ``` + pub fn fallible(self) -> FallibleTask { + FallibleTask(match self.0 { + TaskState::Ready(val) => FallibleTaskState::Ready(val), + TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()), + }) + } } impl Task> @@ -154,6 +178,55 @@ impl Future for Task { } } +/// A task that returns `Option` instead of panicking when cancelled. +#[must_use] +pub struct FallibleTask(FallibleTaskState); + +enum FallibleTaskState { + /// A task that is ready to return a value + Ready(Option), + + /// A task that is currently running (wraps async_task::FallibleTask). + Spawned(async_task::FallibleTask), +} + +impl FallibleTask { + /// Creates a new fallible task that will resolve with the value. + pub fn ready(val: T) -> Self { + FallibleTask(FallibleTaskState::Ready(Some(val))) + } + + /// Detaching a task runs it to completion in the background. + pub fn detach(self) { + match self.0 { + FallibleTaskState::Ready(_) => {} + FallibleTaskState::Spawned(task) => task.detach(), + } + } +} + +impl Future for FallibleTask { + type Output = Option; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + match unsafe { self.get_unchecked_mut() } { + FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()), + FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx), + } + } +} + +impl std::fmt::Debug for FallibleTask { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.0 { + FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(), + FallibleTaskState::Spawned(task) => { + f.debug_tuple("FallibleTask::Spawned").field(task).finish() + } + } + } +} + /// A task label is an opaque identifier that you can use to /// refer to a task in tests. #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] @@ -252,7 +325,10 @@ impl BackgroundExecutor { let (runnable, task) = unsafe { async_task::Builder::new() - .metadata(RunnableMeta { location }) + .metadata(RunnableMeta { + location, + app: None, + }) .spawn_unchecked( move |_| async { let _notify_guard = NotifyOnDrop(pair); @@ -330,7 +406,10 @@ impl BackgroundExecutor { ); async_task::Builder::new() - .metadata(RunnableMeta { location }) + .metadata(RunnableMeta { + location, + app: None, + }) .spawn( move |_| future, move |runnable| { @@ -340,7 +419,10 @@ impl BackgroundExecutor { } else { let location = core::panic::Location::caller(); async_task::Builder::new() - .metadata(RunnableMeta { location }) + .metadata(RunnableMeta { + location, + app: None, + }) .spawn( move |_| future, move |runnable| { @@ -566,7 +648,10 @@ impl BackgroundExecutor { } let location = core::panic::Location::caller(); let (runnable, task) = async_task::Builder::new() - .metadata(RunnableMeta { location }) + .metadata(RunnableMeta { + location, + app: None, + }) .spawn(move |_| async move {}, { let dispatcher = self.dispatcher.clone(); move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable)) @@ -681,7 +766,7 @@ impl ForegroundExecutor { where R: 'static, { - self.spawn_with_priority(Priority::default(), future) + self.inner_spawn(None, Priority::default(), future) } /// Enqueues the given Task to run on the main thread at some point in the future. @@ -691,6 +776,31 @@ impl ForegroundExecutor { priority: Priority, future: impl Future + 'static, ) -> Task + where + R: 'static, + { + self.inner_spawn(None, priority, future) + } + + #[track_caller] + pub(crate) fn spawn_context( + &self, + app: std::sync::Weak<()>, + future: impl Future + 'static, + ) -> Task + where + R: 'static, + { + self.inner_spawn(Some(app), Priority::default(), future) + } + + #[track_caller] + pub(crate) fn inner_spawn( + &self, + app: Option>, + priority: Priority, + future: impl Future + 'static, + ) -> Task where R: 'static, { @@ -702,6 +812,7 @@ impl ForegroundExecutor { dispatcher: Arc, future: AnyLocalFuture, location: &'static core::panic::Location<'static>, + app: Option>, priority: Priority, ) -> Task { let (runnable, task) = spawn_local_with_source_location( @@ -709,12 +820,12 @@ impl ForegroundExecutor { move |runnable| { dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority) }, - RunnableMeta { location }, + RunnableMeta { location, app }, ); runnable.schedule(); Task(TaskState::Spawned(task)) } - inner::(dispatcher, Box::pin(future), location, priority) + inner::(dispatcher, Box::pin(future), location, app, priority) } } @@ -847,3 +958,259 @@ impl Drop for Scope<'_> { self.executor.block(self.rx.next()); } } + +#[cfg(test)] +mod test { + use super::*; + use crate::{App, TestDispatcher, TestPlatform}; + use rand::SeedableRng; + use std::cell::RefCell; + + #[test] + fn sanity_test_tasks_run() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor, foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness); + + let task_ran = Rc::new(RefCell::new(false)); + + foreground_executor + .spawn_context(liveness_token, { + let task_ran = Rc::clone(&task_ran); + async move { + *task_ran.borrow_mut() = true; + } + }) + .detach(); + + // Run dispatcher while app is still alive + dispatcher.run_until_parked(); + + // Task should have run + assert!( + *task_ran.borrow(), + "Task should run normally when app is alive" + ); + } + + #[test] + fn test_task_cancelled_when_app_dropped() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor, foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness); + let app_weak = Rc::downgrade(&app); + + let task_ran = Rc::new(RefCell::new(false)); + let task_ran_clone = Rc::clone(&task_ran); + + foreground_executor + .spawn_context(liveness_token, async move { + *task_ran_clone.borrow_mut() = true; + }) + .detach(); + + drop(app); + + assert!(app_weak.upgrade().is_none(), "App should have been dropped"); + + dispatcher.run_until_parked(); + + // The task should have been cancelled, not run + assert!( + !*task_ran.borrow(), + "Task should have been cancelled when app was dropped, but it ran!" + ); + } + + #[test] + fn test_nested_tasks_both_cancel() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor, foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness); + let app_weak = Rc::downgrade(&app); + + let outer_completed = Rc::new(RefCell::new(false)); + let inner_completed = Rc::new(RefCell::new(false)); + let reached_await = Rc::new(RefCell::new(false)); + + let outer_flag = Rc::clone(&outer_completed); + let inner_flag = Rc::clone(&inner_completed); + let await_flag = Rc::clone(&reached_await); + + // Channel to block the inner task until we're ready + let (tx, rx) = futures::channel::oneshot::channel::<()>(); + + // We need clones of executor and liveness_token for the inner spawn + let inner_executor = foreground_executor.clone(); + let inner_liveness_token = liveness_token.clone(); + + foreground_executor + .spawn_context(liveness_token, async move { + let inner_task = inner_executor.spawn_context(inner_liveness_token, { + let inner_flag = Rc::clone(&inner_flag); + async move { + rx.await.ok(); + *inner_flag.borrow_mut() = true; + } + }); + + *await_flag.borrow_mut() = true; + + inner_task.await; + + *outer_flag.borrow_mut() = true; + }) + .detach(); + + // Run dispatcher until outer task reaches the await point + // The inner task will be blocked on the channel + dispatcher.run_until_parked(); + + // Verify we actually reached the await point before dropping the app + assert!( + *reached_await.borrow(), + "Outer task should have reached the await point" + ); + + // Neither task should have completed yet + assert!( + !*outer_completed.borrow(), + "Outer task should not have completed yet" + ); + assert!( + !*inner_completed.borrow(), + "Inner task should not have completed yet" + ); + + // Drop the channel sender and app while outer is awaiting inner + drop(tx); + drop(app); + assert!(app_weak.upgrade().is_none(), "App should have been dropped"); + + // Run dispatcher - both tasks should be cancelled + dispatcher.run_until_parked(); + + // Neither task should have completed (both were cancelled) + assert!( + !*outer_completed.borrow(), + "Outer task should have been cancelled, not completed" + ); + assert!( + !*inner_completed.borrow(), + "Inner task should have been cancelled, not completed" + ); + } + + #[test] + fn test_task_without_app_tracking_still_runs() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor, foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let app_weak = Rc::downgrade(&app); + + let task_ran = Rc::new(RefCell::new(false)); + let task_ran_clone = Rc::clone(&task_ran); + + let _task = foreground_executor.spawn(async move { + *task_ran_clone.borrow_mut() = true; + }); + + drop(app); + + assert!(app_weak.upgrade().is_none(), "App should have been dropped"); + + dispatcher.run_until_parked(); + + assert!( + *task_ran.borrow(), + "Task without app tracking should still run after app is dropped" + ); + } + + #[test] + #[should_panic] + fn test_polling_cancelled_task_panics() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor.clone(), foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness); + let app_weak = Rc::downgrade(&app); + + let task = foreground_executor.spawn_context(liveness_token, async move { 42 }); + + drop(app); + + assert!(app_weak.upgrade().is_none(), "App should have been dropped"); + + dispatcher.run_until_parked(); + + background_executor.block(task); + } + + #[test] + fn test_polling_cancelled_task_returns_none_with_fallible() { + let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0)); + let arc_dispatcher = Arc::new(dispatcher.clone()); + let background_executor = BackgroundExecutor::new(arc_dispatcher.clone()); + let foreground_executor = ForegroundExecutor::new(arc_dispatcher); + + let platform = TestPlatform::new(background_executor.clone(), foreground_executor.clone()); + let asset_source = Arc::new(()); + let http_client = http_client::FakeHttpClient::with_404_response(); + + let app = App::new_app(platform, asset_source, http_client); + let liveness_token = std::sync::Arc::downgrade(&app.borrow().liveness); + let app_weak = Rc::downgrade(&app); + + let task = foreground_executor + .spawn_context(liveness_token, async move { 42 }) + .fallible(); + + drop(app); + + assert!(app_weak.upgrade().is_none(), "App should have been dropped"); + + dispatcher.run_until_parked(); + + let result = background_executor.block(task); + assert_eq!(result, None, "Cancelled task should return None"); + } +} diff --git a/crates/gpui/src/platform.rs b/crates/gpui/src/platform.rs index 112775890ef6e478f0b2d347bc9c9ae56dac3c73..e2a06eb77d07a0baeeb9e66ebe105b34b6073a88 100644 --- a/crates/gpui/src/platform.rs +++ b/crates/gpui/src/platform.rs @@ -575,10 +575,30 @@ pub(crate) trait PlatformWindow: HasWindowHandle + HasDisplayHandle { /// This type is public so that our test macro can generate and use it, but it should not /// be considered part of our public API. #[doc(hidden)] -#[derive(Debug)] pub struct RunnableMeta { /// Location of the runnable pub location: &'static core::panic::Location<'static>, + /// Weak reference to check if the app is still alive before running this task + pub app: Option>, +} + +impl std::fmt::Debug for RunnableMeta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RunnableMeta") + .field("location", &self.location) + .field("app_alive", &self.is_app_alive()) + .finish() + } +} + +impl RunnableMeta { + /// Returns true if the app is still alive (or if no app tracking is configured). + pub fn is_app_alive(&self) -> bool { + match &self.app { + Some(weak) => weak.strong_count() > 0, + None => true, + } + } } #[doc(hidden)] diff --git a/crates/gpui/src/platform/mac/dispatcher.rs b/crates/gpui/src/platform/mac/dispatcher.rs index 1dfea82d58cbf2387571cabdcd7fbcfcf785c735..c10998ed3da49d1837929ab2ff93d19b035812df 100644 --- a/crates/gpui/src/platform/mac/dispatcher.rs +++ b/crates/gpui/src/platform/mac/dispatcher.rs @@ -251,7 +251,13 @@ extern "C" fn trampoline(runnable: *mut c_void) { let task = unsafe { Runnable::::from_raw(NonNull::new_unchecked(runnable as *mut ())) }; - let location = task.metadata().location; + let metadata = task.metadata(); + let location = metadata.location; + + if !metadata.is_app_alive() { + drop(task); + return; + } let start = Instant::now(); let timing = TaskTiming { diff --git a/crates/gpui/src/platform/test/dispatcher.rs b/crates/gpui/src/platform/test/dispatcher.rs index c271430586106abc93e0bb3258c9e25a06b12383..8ff761cb64d71a7830f2760b682ea7a84759b9bb 100644 --- a/crates/gpui/src/platform/test/dispatcher.rs +++ b/crates/gpui/src/platform/test/dispatcher.rs @@ -177,7 +177,14 @@ impl TestDispatcher { // todo(localcc): add timings to tests match runnable { - RunnableVariant::Meta(runnable) => runnable.run(), + RunnableVariant::Meta(runnable) => { + if !runnable.metadata().is_app_alive() { + drop(runnable); + self.state.lock().is_main_thread = was_main_thread; + return true; + } + runnable.run() + } RunnableVariant::Compat(runnable) => runnable.run(), };