Start work on detecting leaked handles in tests

Max Brunsfeld and Nathan Sobo created

For now, just track models. Tests fail because we don't
yet clear the app contexts at the right time.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/gpui/src/app.rs                | 189 ++++++++++++++++++++++------
crates/gpui/src/executor.rs           |  37 -----
crates/gpui/src/test.rs               |  20 +-
crates/gpui/src/util.rs               |  27 +++
crates/gpui_macros/src/gpui_macros.rs |   1 
crates/server/src/rpc.rs              |  11 +
6 files changed, 196 insertions(+), 89 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -4,10 +4,11 @@ use crate::{
     keymap::{self, Keystroke},
     platform::{self, CursorStyle, Platform, PromptLevel, WindowOptions},
     presenter::Presenter,
-    util::{post_inc, timeout},
+    util::{post_inc, timeout, CwdBacktrace},
     AssetCache, AssetSource, ClipboardItem, FontCache, PathPromptOptions, TextLayoutCache,
 };
 use anyhow::{anyhow, Result};
+use backtrace::Backtrace;
 use keymap::MatchResult;
 use parking_lot::Mutex;
 use platform::Event;
@@ -235,7 +236,6 @@ pub struct App(Rc<RefCell<MutableAppContext>>);
 #[derive(Clone)]
 pub struct AsyncAppContext(Rc<RefCell<MutableAppContext>>);
 
-#[derive(Clone)]
 pub struct TestAppContext {
     cx: Rc<RefCell<MutableAppContext>>,
     foreground_platform: Rc<platform::test::ForegroundPlatform>,
@@ -252,6 +252,7 @@ impl App {
             platform.clone(),
             foreground_platform.clone(),
             Arc::new(FontCache::new(platform.fonts())),
+            Default::default(),
             asset_source,
         ))));
 
@@ -389,6 +390,7 @@ impl TestAppContext {
         foreground: Rc<executor::Foreground>,
         background: Arc<executor::Background>,
         font_cache: Arc<FontCache>,
+        leak_detector: Arc<Mutex<LeakDetector>>,
         first_entity_id: usize,
     ) -> Self {
         let mut cx = MutableAppContext::new(
@@ -397,6 +399,11 @@ impl TestAppContext {
             platform,
             foreground_platform.clone(),
             font_cache,
+            RefCounts {
+                #[cfg(feature = "test-support")]
+                leak_detector,
+                ..Default::default()
+            },
             (),
         );
         cx.next_entity_id = first_entity_id;
@@ -551,6 +558,11 @@ impl TestAppContext {
             .expect("prompt was not called");
         let _ = done_tx.try_send(answer);
     }
+
+    #[cfg(feature = "test-support")]
+    pub fn leak_detector(&self) -> Arc<Mutex<LeakDetector>> {
+        self.cx.borrow().leak_detector()
+    }
 }
 
 impl AsyncAppContext {
@@ -758,8 +770,8 @@ impl MutableAppContext {
         platform: Arc<dyn platform::Platform>,
         foreground_platform: Rc<dyn platform::ForegroundPlatform>,
         font_cache: Arc<FontCache>,
+        ref_counts: RefCounts,
         asset_source: impl AssetSource,
-        // entity_drop_tx:
     ) -> Self {
         Self {
             weak_self: None,
@@ -771,7 +783,7 @@ impl MutableAppContext {
                 windows: Default::default(),
                 app_states: Default::default(),
                 element_states: Default::default(),
-                ref_counts: Arc::new(Mutex::new(RefCounts::default())),
+                ref_counts: Arc::new(Mutex::new(ref_counts)),
                 background,
                 font_cache,
                 platform,
@@ -1808,6 +1820,11 @@ impl MutableAppContext {
     pub fn read_from_clipboard(&self) -> Option<ClipboardItem> {
         self.cx.platform.read_from_clipboard()
     }
+
+    #[cfg(feature = "test-support")]
+    pub fn leak_detector(&self) -> Arc<Mutex<LeakDetector>> {
+        self.cx.ref_counts.lock().leak_detector.clone()
+    }
 }
 
 impl ReadModel for MutableAppContext {
@@ -2003,12 +2020,11 @@ impl UpgradeModelHandle for AppContext {
 
     fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
         if self.models.contains_key(&handle.model_id) {
-            self.ref_counts.lock().inc_model(handle.model_id);
-            Some(AnyModelHandle {
-                model_id: handle.model_id,
-                model_type: handle.model_type,
-                ref_counts: self.ref_counts.clone(),
-            })
+            Some(AnyModelHandle::new(
+                handle.model_id,
+                handle.model_type,
+                self.ref_counts.clone(),
+            ))
         } else {
             None
         }
@@ -2814,19 +2830,33 @@ pub enum EntityLocation {
     View(usize, usize),
 }
 
-pub struct ModelHandle<T> {
+pub struct ModelHandle<T: Entity> {
     model_id: usize,
     model_type: PhantomData<T>,
     ref_counts: Arc<Mutex<RefCounts>>,
+
+    #[cfg(feature = "test-support")]
+    handle_id: usize,
 }
 
 impl<T: Entity> ModelHandle<T> {
     fn new(model_id: usize, ref_counts: &Arc<Mutex<RefCounts>>) -> Self {
         ref_counts.lock().inc_model(model_id);
+
+        #[cfg(feature = "test-support")]
+        let handle_id = ref_counts
+            .lock()
+            .leak_detector
+            .lock()
+            .handle_created(Some(type_name::<T>()), model_id);
+
         Self {
             model_id,
             model_type: PhantomData,
             ref_counts: ref_counts.clone(),
+
+            #[cfg(feature = "test-support")]
+            handle_id,
         }
     }
 
@@ -2975,44 +3005,39 @@ impl<T: Entity> ModelHandle<T> {
     }
 }
 
-impl<T> Clone for ModelHandle<T> {
+impl<T: Entity> Clone for ModelHandle<T> {
     fn clone(&self) -> Self {
-        self.ref_counts.lock().inc_model(self.model_id);
-        Self {
-            model_id: self.model_id,
-            model_type: PhantomData,
-            ref_counts: self.ref_counts.clone(),
-        }
+        Self::new(self.model_id, &self.ref_counts)
     }
 }
 
-impl<T> PartialEq for ModelHandle<T> {
+impl<T: Entity> PartialEq for ModelHandle<T> {
     fn eq(&self, other: &Self) -> bool {
         self.model_id == other.model_id
     }
 }
 
-impl<T> Eq for ModelHandle<T> {}
+impl<T: Entity> Eq for ModelHandle<T> {}
 
-impl<T> PartialEq<WeakModelHandle<T>> for ModelHandle<T> {
+impl<T: Entity> PartialEq<WeakModelHandle<T>> for ModelHandle<T> {
     fn eq(&self, other: &WeakModelHandle<T>) -> bool {
         self.model_id == other.model_id
     }
 }
 
-impl<T> Hash for ModelHandle<T> {
+impl<T: Entity> Hash for ModelHandle<T> {
     fn hash<H: Hasher>(&self, state: &mut H) {
         self.model_id.hash(state);
     }
 }
 
-impl<T> std::borrow::Borrow<usize> for ModelHandle<T> {
+impl<T: Entity> std::borrow::Borrow<usize> for ModelHandle<T> {
     fn borrow(&self) -> &usize {
         &self.model_id
     }
 }
 
-impl<T> Debug for ModelHandle<T> {
+impl<T: Entity> Debug for ModelHandle<T> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_tuple(&format!("ModelHandle<{}>", type_name::<T>()))
             .field(&self.model_id)
@@ -3020,12 +3045,19 @@ impl<T> Debug for ModelHandle<T> {
     }
 }
 
-unsafe impl<T> Send for ModelHandle<T> {}
-unsafe impl<T> Sync for ModelHandle<T> {}
+unsafe impl<T: Entity> Send for ModelHandle<T> {}
+unsafe impl<T: Entity> Sync for ModelHandle<T> {}
 
-impl<T> Drop for ModelHandle<T> {
+impl<T: Entity> Drop for ModelHandle<T> {
     fn drop(&mut self) {
-        self.ref_counts.lock().dec_model(self.model_id);
+        let mut ref_counts = self.ref_counts.lock();
+        ref_counts.dec_model(self.model_id);
+
+        #[cfg(feature = "test-support")]
+        ref_counts
+            .leak_detector
+            .lock()
+            .handle_dropped(self.model_id, self.handle_id);
     }
 }
 
@@ -3431,15 +3463,41 @@ pub struct AnyModelHandle {
     model_id: usize,
     model_type: TypeId,
     ref_counts: Arc<Mutex<RefCounts>>,
+
+    #[cfg(feature = "test-support")]
+    handle_id: usize,
 }
 
 impl AnyModelHandle {
+    fn new(model_id: usize, model_type: TypeId, ref_counts: Arc<Mutex<RefCounts>>) -> Self {
+        ref_counts.lock().inc_model(model_id);
+
+        #[cfg(feature = "test-support")]
+        let handle_id = ref_counts
+            .lock()
+            .leak_detector
+            .lock()
+            .handle_created(None, model_id);
+
+        Self {
+            model_id,
+            model_type,
+            ref_counts,
+
+            #[cfg(feature = "test-support")]
+            handle_id,
+        }
+    }
+
     pub fn downcast<T: Entity>(self) -> Option<ModelHandle<T>> {
         if self.is::<T>() {
             let result = Some(ModelHandle {
                 model_id: self.model_id,
                 model_type: PhantomData,
                 ref_counts: self.ref_counts.clone(),
+
+                #[cfg(feature = "test-support")]
+                handle_id: self.handle_id,
             });
             unsafe {
                 Arc::decrement_strong_count(&self.ref_counts);
@@ -3465,29 +3523,30 @@ impl AnyModelHandle {
 
 impl<T: Entity> From<ModelHandle<T>> for AnyModelHandle {
     fn from(handle: ModelHandle<T>) -> Self {
-        handle.ref_counts.lock().inc_model(handle.model_id);
-        Self {
-            model_id: handle.model_id,
-            model_type: TypeId::of::<T>(),
-            ref_counts: handle.ref_counts.clone(),
-        }
+        Self::new(
+            handle.model_id,
+            TypeId::of::<T>(),
+            handle.ref_counts.clone(),
+        )
     }
 }
 
 impl Clone for AnyModelHandle {
     fn clone(&self) -> Self {
-        self.ref_counts.lock().inc_model(self.model_id);
-        Self {
-            model_id: self.model_id,
-            model_type: self.model_type,
-            ref_counts: self.ref_counts.clone(),
-        }
+        Self::new(self.model_id, self.model_type, self.ref_counts.clone())
     }
 }
 
 impl Drop for AnyModelHandle {
     fn drop(&mut self) {
-        self.ref_counts.lock().dec_model(self.model_id);
+        let mut ref_counts = self.ref_counts.lock();
+        ref_counts.dec_model(self.model_id);
+
+        #[cfg(feature = "test-support")]
+        ref_counts
+            .leak_detector
+            .lock()
+            .handle_dropped(self.model_id, self.handle_id);
     }
 }
 
@@ -3694,6 +3753,51 @@ impl Drop for Subscription {
     }
 }
 
+#[derive(Default)]
+pub struct LeakDetector {
+    next_handle_id: usize,
+    handle_backtraces: HashMap<usize, (Option<&'static str>, HashMap<usize, Backtrace>)>,
+}
+
+impl LeakDetector {
+    fn handle_created(&mut self, type_name: Option<&'static str>, entity_id: usize) -> usize {
+        let handle_id = post_inc(&mut self.next_handle_id);
+        let entry = self.handle_backtraces.entry(entity_id).or_default();
+        if let Some(type_name) = type_name {
+            entry.0.get_or_insert(type_name);
+        }
+        entry.1.insert(handle_id, Backtrace::new_unresolved());
+        handle_id
+    }
+
+    fn handle_dropped(&mut self, entity_id: usize, handle_id: usize) {
+        if let Some((_, backtraces)) = self.handle_backtraces.get_mut(&entity_id) {
+            assert!(backtraces.remove(&handle_id).is_some());
+            if backtraces.is_empty() {
+                self.handle_backtraces.remove(&entity_id);
+            }
+        }
+    }
+
+    pub fn detect(&mut self) {
+        let mut found_leaks = false;
+        for (id, (type_name, backtraces)) in self.handle_backtraces.iter_mut() {
+            eprintln!(
+                "leaked {} handles to {:?} {}",
+                backtraces.len(),
+                type_name.unwrap_or("entity"),
+                id
+            );
+            for trace in backtraces.values_mut() {
+                trace.resolve();
+                eprintln!("{:?}", CwdBacktrace(trace));
+            }
+            found_leaks = true;
+        }
+        assert!(!found_leaks, "detected leaked handles");
+    }
+}
+
 #[derive(Default)]
 struct RefCounts {
     entity_counts: HashMap<usize, usize>,
@@ -3701,6 +3805,9 @@ struct RefCounts {
     dropped_models: HashSet<usize>,
     dropped_views: HashSet<(usize, usize)>,
     dropped_element_states: HashSet<ElementStateId>,
+
+    #[cfg(feature = "test-support")]
+    leak_detector: Arc<Mutex<LeakDetector>>,
 }
 
 struct ElementStateRefCount {

crates/gpui/src/executor.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{anyhow, Result};
 use async_task::Runnable;
-use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
+use backtrace::Backtrace;
 use collections::HashMap;
 use parking_lot::Mutex;
 use postage::{barrier, prelude::Stream as _};
@@ -8,7 +8,7 @@ use rand::prelude::*;
 use smol::{channel, future::yield_now, prelude::*, Executor, Timer};
 use std::{
     any::Any,
-    fmt::{self, Debug, Display},
+    fmt::{self, Display},
     marker::PhantomData,
     mem,
     ops::RangeInclusive,
@@ -282,7 +282,7 @@ impl DeterministicState {
                 backtrace.resolve();
                 backtrace_message = format!(
                     "\nbacktrace of waiting future:\n{:?}",
-                    CwdBacktrace::new(backtrace)
+                    util::CwdBacktrace(backtrace)
                 );
             }
 
@@ -294,37 +294,6 @@ impl DeterministicState {
     }
 }
 
-struct CwdBacktrace<'a> {
-    backtrace: &'a Backtrace,
-}
-
-impl<'a> CwdBacktrace<'a> {
-    fn new(backtrace: &'a Backtrace) -> Self {
-        Self { backtrace }
-    }
-}
-
-impl<'a> Debug for CwdBacktrace<'a> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        let cwd = std::env::current_dir().unwrap();
-        let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
-            fmt::Display::fmt(&path, fmt)
-        };
-        let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
-        for frame in self.backtrace.frames() {
-            let mut formatted_frame = fmt.frame();
-            if frame
-                .symbols()
-                .iter()
-                .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
-            {
-                formatted_frame.backtrace_frame(frame)?;
-            }
-        }
-        fmt.finish()
-    }
-}
-
 impl Foreground {
     pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
         if dispatcher.is_main_thread() {

crates/gpui/src/test.rs 🔗

@@ -1,3 +1,10 @@
+use crate::{
+    executor, platform, Entity, FontCache, Handle, LeakDetector, MutableAppContext, Platform,
+    Subscription, TestAppContext,
+};
+use futures::StreamExt;
+use parking_lot::Mutex;
+use smol::channel;
 use std::{
     panic::{self, RefUnwindSafe},
     rc::Rc,
@@ -7,14 +14,6 @@ use std::{
     },
 };
 
-use futures::StreamExt;
-use smol::channel;
-
-use crate::{
-    executor, platform, Entity, FontCache, Handle, MutableAppContext, Platform, Subscription,
-    TestAppContext,
-};
-
 #[cfg(test)]
 #[ctor::ctor]
 fn init_logger() {
@@ -65,24 +64,27 @@ pub fn run_test(
                 }
 
                 let deterministic = executor::Deterministic::new(seed);
+                let leak_detector = Arc::new(Mutex::new(LeakDetector::default()));
                 let mut cx = TestAppContext::new(
                     foreground_platform.clone(),
                     platform.clone(),
                     deterministic.build_foreground(usize::MAX),
                     deterministic.build_background(),
                     font_cache.clone(),
+                    leak_detector.clone(),
                     0,
                 );
                 cx.update(|cx| {
                     test_fn(
                         cx,
                         foreground_platform.clone(),
-                        deterministic,
+                        deterministic.clone(),
                         seed,
                         is_last_iteration,
                     )
                 });
 
+                leak_detector.lock().detect();
                 if is_last_iteration {
                     break;
                 }

crates/gpui/src/util.rs 🔗

@@ -1,5 +1,6 @@
+use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
 use smol::future::FutureExt;
-use std::{future::Future, time::Duration};
+use std::{fmt, future::Future, time::Duration};
 
 pub fn post_inc(value: &mut usize) -> usize {
     let prev = *value;
@@ -18,3 +19,27 @@ where
     let future = async move { Ok(f.await) };
     timer.race(future).await
 }
+
+pub struct CwdBacktrace<'a>(pub &'a Backtrace);
+
+impl<'a> std::fmt::Debug for CwdBacktrace<'a> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        let cwd = std::env::current_dir().unwrap();
+        let cwd = cwd.parent().unwrap();
+        let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
+            fmt::Display::fmt(&path, fmt)
+        };
+        let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
+        for frame in self.0.frames() {
+            let mut formatted_frame = fmt.frame();
+            if frame
+                .symbols()
+                .iter()
+                .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
+            {
+                formatted_frame.backtrace_frame(frame)?;
+            }
+        }
+        fmt.finish()
+    }
+}

crates/gpui_macros/src/gpui_macros.rs 🔗

@@ -80,6 +80,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                                     deterministic.build_foreground(#ix),
                                     deterministic.build_background(),
                                     cx.font_cache().clone(),
+                                    cx.leak_detector(),
                                     #first_entity_id,
                                 ),
                             ));

crates/server/src/rpc.rs 🔗

@@ -4213,6 +4213,7 @@ mod tests {
             cx.foreground(),
             cx.background(),
             cx.font_cache(),
+            cx.leak_detector(),
             next_entity_id,
         );
         let host = server.create_client(&mut host_cx, "host").await;
@@ -4249,7 +4250,7 @@ mod tests {
             operations.clone(),
             max_operations,
             rng.clone(),
-            host_cx.clone(),
+            host_cx,
         )));
 
         while operations.get() < max_operations {
@@ -4266,6 +4267,7 @@ mod tests {
                     cx.foreground(),
                     cx.background(),
                     cx.font_cache(),
+                    cx.leak_detector(),
                     next_entity_id,
                 );
                 let guest = server
@@ -4276,7 +4278,7 @@ mod tests {
                     guest.client.clone(),
                     guest.user_store.clone(),
                     guest_lang_registry.clone(),
-                    fs.clone(),
+                    FakeFs::new(cx.background()),
                     &mut guest_cx.to_async(),
                 )
                 .await
@@ -4294,9 +4296,10 @@ mod tests {
             }
         }
 
-        let clients = futures::future::join_all(clients).await;
+        let mut clients = futures::future::join_all(clients).await;
         cx.foreground().run_until_parked();
 
+        let (_, host_cx) = clients.remove(0);
         let host_worktree_snapshots = host_project.read_with(&host_cx, |project, cx| {
             project
                 .worktrees(cx)
@@ -4307,7 +4310,7 @@ mod tests {
                 .collect::<BTreeMap<_, _>>()
         });
 
-        for (guest_client, guest_cx) in clients.iter().skip(1) {
+        for (guest_client, guest_cx) in clients.iter() {
             let guest_id = guest_client.client.id();
             let worktree_snapshots =
                 guest_client