entity_map.rs

   1use crate::{App, AppContext, GpuiBorrow, VisualContext, Window, seal::Sealed};
   2use anyhow::{Context as _, Result};
   3use collections::FxHashSet;
   4use derive_more::{Deref, DerefMut};
   5use parking_lot::{RwLock, RwLockUpgradableReadGuard};
   6use slotmap::{KeyData, SecondaryMap, SlotMap};
   7use std::{
   8    any::{Any, TypeId, type_name},
   9    cell::RefCell,
  10    cmp::Ordering,
  11    fmt::{self, Display},
  12    hash::{Hash, Hasher},
  13    marker::PhantomData,
  14    num::NonZeroU64,
  15    sync::{
  16        Arc, Weak,
  17        atomic::{AtomicU64, AtomicUsize, Ordering::SeqCst},
  18    },
  19    thread::panicking,
  20};
  21
  22use super::Context;
  23use crate::util::atomic_incr_if_not_zero;
  24#[cfg(any(test, feature = "leak-detection"))]
  25use collections::HashMap;
  26
slotmap::new_key_type! {
    /// A unique identifier for an entity across the application.
    ///
    /// Ids are allocated by inserting a reference count into the `SlotMap`
    /// held by `EntityRefCounts` (see `EntityMap::reserve`).
    pub struct EntityId;
}
  31
  32impl From<u64> for EntityId {
  33    fn from(value: u64) -> Self {
  34        Self(KeyData::from_ffi(value))
  35    }
  36}
  37
  38impl EntityId {
  39    /// Converts this entity id to a [NonZeroU64]
  40    pub fn as_non_zero_u64(self) -> NonZeroU64 {
  41        NonZeroU64::new(self.0.as_ffi()).unwrap()
  42    }
  43
  44    /// Converts this entity id to a [u64]
  45    pub fn as_u64(self) -> u64 {
  46        self.0.as_ffi()
  47    }
  48}
  49
  50impl Display for EntityId {
  51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  52        write!(f, "{}", self.as_u64())
  53    }
  54}
  55
/// Owns type-erased entity values keyed by [`EntityId`], together with the
/// reference-count table shared with entity handles.
pub(crate) struct EntityMap {
    // Entity values. An id has no entry here while its entity is leased out
    // (`lease` removes it, `end_lease` puts it back) or when a reserved slot
    // was never filled via `insert`.
    entities: SecondaryMap<EntityId, Box<dyn Any>>,
    // Ids recorded as accessed by `insert`, `lease`, `read` and
    // `extend_accessed`; emptied by `clear_accessed`, and dropped ids are
    // removed in `take_dropped`.
    pub accessed_entities: RefCell<FxHashSet<EntityId>>,
    // Reference counts; handles hold a `Weak` to this allocation.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
  61
/// Reference-count bookkeeping shared between [`EntityMap`] and entity handles.
#[doc(hidden)]
pub(crate) struct EntityRefCounts {
    // Strong handle count per entity. Inserting here is also what allocates
    // new `EntityId`s (see `EntityMap::reserve`).
    counts: SlotMap<EntityId, AtomicUsize>,
    // Ids whose count reached zero (pushed by `AnyEntity::drop`), waiting to
    // be collected by `EntityMap::take_dropped`.
    dropped_entity_ids: Vec<EntityId>,
    // Per-handle tracking used to report leaks; only in tests or with the
    // `leak-detection` feature.
    #[cfg(any(test, feature = "leak-detection"))]
    leak_detector: LeakDetector,
}
  69
  70impl EntityMap {
  71    pub fn new() -> Self {
  72        Self {
  73            entities: SecondaryMap::new(),
  74            accessed_entities: RefCell::new(FxHashSet::default()),
  75            ref_counts: Arc::new(RwLock::new(EntityRefCounts {
  76                counts: SlotMap::with_key(),
  77                dropped_entity_ids: Vec::new(),
  78                #[cfg(any(test, feature = "leak-detection"))]
  79                leak_detector: LeakDetector {
  80                    next_handle_id: 0,
  81                    entity_handles: HashMap::default(),
  82                },
  83            })),
  84        }
  85    }
  86
  87    #[doc(hidden)]
  88    pub fn ref_counts_drop_handle(&self) -> Arc<RwLock<EntityRefCounts>> {
  89        self.ref_counts.clone()
  90    }
  91
  92    /// Captures a snapshot of all entities that currently have alive handles.
  93    ///
  94    /// The returned [`LeakDetectorSnapshot`] can later be passed to
  95    /// [`assert_no_new_leaks`](Self::assert_no_new_leaks) to verify that no
  96    /// entities created after the snapshot are still alive.
  97    #[cfg(any(test, feature = "leak-detection"))]
  98    pub fn leak_detector_snapshot(&self) -> LeakDetectorSnapshot {
  99        self.ref_counts.read().leak_detector.snapshot()
 100    }
 101
 102    /// Asserts that no entities created after `snapshot` still have alive handles.
 103    ///
 104    /// See [`LeakDetector::assert_no_new_leaks`] for details.
 105    #[cfg(any(test, feature = "leak-detection"))]
 106    pub fn assert_no_new_leaks(&self, snapshot: &LeakDetectorSnapshot) {
 107        self.ref_counts
 108            .read()
 109            .leak_detector
 110            .assert_no_new_leaks(snapshot)
 111    }
 112
 113    /// Reserve a slot for an entity, which you can subsequently use with `insert`.
 114    pub fn reserve<T: 'static>(&self) -> Slot<T> {
 115        let id = self.ref_counts.write().counts.insert(1.into());
 116        Slot(Entity::new(id, Arc::downgrade(&self.ref_counts)))
 117    }
 118
 119    /// Insert an entity into a slot obtained by calling `reserve`.
 120    pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Entity<T>
 121    where
 122        T: 'static,
 123    {
 124        let mut accessed_entities = self.accessed_entities.get_mut();
 125        accessed_entities.insert(slot.entity_id);
 126
 127        let handle = slot.0;
 128        self.entities.insert(handle.entity_id, Box::new(entity));
 129        handle
 130    }
 131
 132    /// Move an entity to the stack.
 133    #[track_caller]
 134    pub fn lease<T>(&mut self, pointer: &Entity<T>) -> Lease<T> {
 135        self.assert_valid_context(pointer);
 136        let mut accessed_entities = self.accessed_entities.get_mut();
 137        accessed_entities.insert(pointer.entity_id);
 138
 139        let entity = Some(
 140            self.entities
 141                .remove(pointer.entity_id)
 142                .unwrap_or_else(|| double_lease_panic::<T>("update")),
 143        );
 144        Lease {
 145            entity,
 146            id: pointer.entity_id,
 147            entity_type: PhantomData,
 148        }
 149    }
 150
 151    /// Returns an entity after moving it to the stack.
 152    pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
 153        self.entities.insert(lease.id, lease.entity.take().unwrap());
 154    }
 155
 156    pub fn read<T: 'static>(&self, entity: &Entity<T>) -> &T {
 157        self.assert_valid_context(entity);
 158        let mut accessed_entities = self.accessed_entities.borrow_mut();
 159        accessed_entities.insert(entity.entity_id);
 160
 161        self.entities
 162            .get(entity.entity_id)
 163            .and_then(|entity| entity.downcast_ref())
 164            .unwrap_or_else(|| double_lease_panic::<T>("read"))
 165    }
 166
 167    fn assert_valid_context(&self, entity: &AnyEntity) {
 168        debug_assert!(
 169            Weak::ptr_eq(&entity.entity_map, &Arc::downgrade(&self.ref_counts)),
 170            "used a entity with the wrong context"
 171        );
 172    }
 173
 174    pub fn extend_accessed(&mut self, entities: &FxHashSet<EntityId>) {
 175        self.accessed_entities
 176            .get_mut()
 177            .extend(entities.iter().copied());
 178    }
 179
 180    pub fn clear_accessed(&mut self) {
 181        self.accessed_entities.get_mut().clear();
 182    }
 183
 184    pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
 185        let mut ref_counts = &mut *self.ref_counts.write();
 186        let dropped_entity_ids = ref_counts.dropped_entity_ids.drain(..);
 187        let mut accessed_entities = self.accessed_entities.get_mut();
 188
 189        dropped_entity_ids
 190            .filter_map(|entity_id| {
 191                let count = ref_counts.counts.remove(entity_id).unwrap();
 192                debug_assert_eq!(
 193                    count.load(SeqCst),
 194                    0,
 195                    "dropped an entity that was referenced"
 196                );
 197                accessed_entities.remove(&entity_id);
 198                // If the EntityId was allocated with `Context::reserve`,
 199                // the entity may not have been inserted.
 200                Some((entity_id, self.entities.remove(entity_id)?))
 201            })
 202            .collect()
 203    }
 204}
 205
 206#[track_caller]
 207fn double_lease_panic<T>(operation: &str) -> ! {
 208    panic!(
 209        "cannot {operation} {} while it is already being updated",
 210        std::any::type_name::<T>()
 211    )
 212}
 213
/// An entity moved out of the [`EntityMap`] onto the stack by `EntityMap::lease`.
///
/// It must be handed back with `EntityMap::end_lease`; dropping it while it
/// still holds the entity panics (unless the thread is already panicking).
pub(crate) struct Lease<T> {
    // The leased value; taken (set to `None`) by `end_lease`.
    entity: Option<Box<dyn Any>>,
    // Id the value is re-inserted under by `end_lease`.
    pub id: EntityId,
    entity_type: PhantomData<T>,
}
 219
 220impl<T: 'static> core::ops::Deref for Lease<T> {
 221    type Target = T;
 222
 223    fn deref(&self) -> &Self::Target {
 224        self.entity.as_ref().unwrap().downcast_ref().unwrap()
 225    }
 226}
 227
 228impl<T: 'static> core::ops::DerefMut for Lease<T> {
 229    fn deref_mut(&mut self) -> &mut Self::Target {
 230        self.entity.as_mut().unwrap().downcast_mut().unwrap()
 231    }
 232}
 233
 234impl<T> Drop for Lease<T> {
 235    fn drop(&mut self) {
 236        if self.entity.is_some() && !panicking() {
 237            panic!("Leases must be ended with EntityMap::end_lease")
 238        }
 239    }
 240}
 241
/// A reserved entity id (from `EntityMap::reserve`) whose value has not been
/// inserted yet; hand it to `EntityMap::insert` to obtain the [`Entity`].
#[derive(Deref, DerefMut)]
pub(crate) struct Slot<T>(Entity<T>);
 244
/// A dynamically typed reference to an entity, which can be downcast into an `Entity<T>`.
///
/// Each live handle contributes one to the entity's strong count.
pub struct AnyEntity {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the concrete entity type; checked by `downcast`.
    pub(crate) entity_type: TypeId,
    // Weak link to the shared ref-count table; fails to upgrade once it is gone.
    entity_map: Weak<RwLock<EntityRefCounts>>,
    // Identifies this particular handle to the leak detector.
    #[cfg(any(test, feature = "leak-detection"))]
    handle_id: HandleId,
}
 253
 254impl AnyEntity {
 255    fn new(
 256        id: EntityId,
 257        entity_type: TypeId,
 258        entity_map: Weak<RwLock<EntityRefCounts>>,
 259        #[cfg(any(test, feature = "leak-detection"))] type_name: &'static str,
 260    ) -> Self {
 261        Self {
 262            entity_id: id,
 263            entity_type,
 264            #[cfg(any(test, feature = "leak-detection"))]
 265            handle_id: entity_map
 266                .clone()
 267                .upgrade()
 268                .unwrap()
 269                .write()
 270                .leak_detector
 271                .handle_created(id, Some(type_name)),
 272            entity_map,
 273        }
 274    }
 275
 276    /// Returns the id associated with this entity.
 277    #[inline]
 278    pub fn entity_id(&self) -> EntityId {
 279        self.entity_id
 280    }
 281
 282    /// Returns the [TypeId] associated with this entity.
 283    #[inline]
 284    pub fn entity_type(&self) -> TypeId {
 285        self.entity_type
 286    }
 287
 288    /// Converts this entity handle into a weak variant, which does not prevent it from being released.
 289    pub fn downgrade(&self) -> AnyWeakEntity {
 290        AnyWeakEntity {
 291            entity_id: self.entity_id,
 292            entity_type: self.entity_type,
 293            entity_ref_counts: self.entity_map.clone(),
 294        }
 295    }
 296
 297    /// Converts this entity handle into a strongly-typed entity handle of the given type.
 298    /// If this entity handle is not of the specified type, returns itself as an error variant.
 299    pub fn downcast<T: 'static>(self) -> Result<Entity<T>, AnyEntity> {
 300        if TypeId::of::<T>() == self.entity_type {
 301            Ok(Entity {
 302                any_entity: self,
 303                entity_type: PhantomData,
 304            })
 305        } else {
 306            Err(self)
 307        }
 308    }
 309}
 310
 311impl Clone for AnyEntity {
 312    fn clone(&self) -> Self {
 313        if let Some(entity_map) = self.entity_map.upgrade() {
 314            let entity_map = entity_map.read();
 315            let count = entity_map
 316                .counts
 317                .get(self.entity_id)
 318                .expect("detected over-release of a entity");
 319            let prev_count = count.fetch_add(1, SeqCst);
 320            assert_ne!(prev_count, 0, "Detected over-release of a entity.");
 321        }
 322
 323        Self {
 324            entity_id: self.entity_id,
 325            entity_type: self.entity_type,
 326            entity_map: self.entity_map.clone(),
 327            #[cfg(any(test, feature = "leak-detection"))]
 328            handle_id: self
 329                .entity_map
 330                .upgrade()
 331                .unwrap()
 332                .write()
 333                .leak_detector
 334                .handle_created(self.entity_id, None),
 335        }
 336    }
 337}
 338
 339impl Drop for AnyEntity {
 340    fn drop(&mut self) {
 341        if let Some(entity_map) = self.entity_map.upgrade() {
 342            let entity_map = entity_map.upgradable_read();
 343            let count = entity_map
 344                .counts
 345                .get(self.entity_id)
 346                .expect("detected over-release of a handle.");
 347            let prev_count = count.fetch_sub(1, SeqCst);
 348            assert_ne!(prev_count, 0, "Detected over-release of a entity.");
 349            if prev_count == 1 {
 350                // We were the last reference to this entity, so we can remove it.
 351                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
 352                entity_map.dropped_entity_ids.push(self.entity_id);
 353            }
 354        }
 355
 356        #[cfg(any(test, feature = "leak-detection"))]
 357        if let Some(entity_map) = self.entity_map.upgrade() {
 358            entity_map
 359                .write()
 360                .leak_detector
 361                .handle_released(self.entity_id, self.handle_id)
 362        }
 363    }
 364}
 365
 366impl<T> From<Entity<T>> for AnyEntity {
 367    #[inline]
 368    fn from(entity: Entity<T>) -> Self {
 369        entity.any_entity
 370    }
 371}
 372
 373impl Hash for AnyEntity {
 374    #[inline]
 375    fn hash<H: Hasher>(&self, state: &mut H) {
 376        self.entity_id.hash(state);
 377    }
 378}
 379
 380impl PartialEq for AnyEntity {
 381    #[inline]
 382    fn eq(&self, other: &Self) -> bool {
 383        self.entity_id == other.entity_id
 384    }
 385}
 386
 387impl Eq for AnyEntity {}
 388
 389impl Ord for AnyEntity {
 390    #[inline]
 391    fn cmp(&self, other: &Self) -> Ordering {
 392        self.entity_id.cmp(&other.entity_id)
 393    }
 394}
 395
 396impl PartialOrd for AnyEntity {
 397    #[inline]
 398    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 399        Some(self.cmp(other))
 400    }
 401}
 402
 403impl std::fmt::Debug for AnyEntity {
 404    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 405        f.debug_struct("AnyEntity")
 406            .field("entity_id", &self.entity_id.as_u64())
 407            .finish()
 408    }
 409}
 410
/// A strong, well-typed reference to a struct which is managed
/// by GPUI.
///
/// Dereferences to the underlying [`AnyEntity`] handle.
#[derive(Deref, DerefMut)]
pub struct Entity<T> {
    #[deref]
    #[deref_mut]
    pub(crate) any_entity: AnyEntity,
    // `fn(T) -> T` makes the marker invariant in `T` without owning a `T`;
    // fn pointers are `Send + Sync` regardless of `T`.
    pub(crate) entity_type: PhantomData<fn(T) -> T>,
}

impl<T> Sealed for Entity<T> {}
 422
 423impl<T: 'static> Entity<T> {
 424    #[inline]
 425    fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
 426    where
 427        T: 'static,
 428    {
 429        Self {
 430            any_entity: AnyEntity::new(
 431                id,
 432                TypeId::of::<T>(),
 433                entity_map,
 434                #[cfg(any(test, feature = "leak-detection"))]
 435                std::any::type_name::<T>(),
 436            ),
 437            entity_type: PhantomData,
 438        }
 439    }
 440
 441    /// Get the entity ID associated with this entity
 442    #[inline]
 443    pub fn entity_id(&self) -> EntityId {
 444        self.any_entity.entity_id
 445    }
 446
 447    /// Downgrade this entity pointer to a non-retaining weak pointer
 448    #[inline]
 449    pub fn downgrade(&self) -> WeakEntity<T> {
 450        WeakEntity {
 451            any_entity: self.any_entity.downgrade(),
 452            entity_type: self.entity_type,
 453        }
 454    }
 455
 456    /// Convert this into a dynamically typed entity.
 457    #[inline]
 458    pub fn into_any(self) -> AnyEntity {
 459        self.any_entity
 460    }
 461
 462    /// Grab a reference to this entity from the context.
 463    #[inline]
 464    pub fn read<'a>(&self, cx: &'a App) -> &'a T {
 465        cx.entities.read(self)
 466    }
 467
 468    /// Read the entity referenced by this handle with the given function.
 469    #[inline]
 470    pub fn read_with<R, C: AppContext>(&self, cx: &C, f: impl FnOnce(&T, &App) -> R) -> R {
 471        cx.read_entity(self, f)
 472    }
 473
 474    /// Updates the entity referenced by this handle with the given function.
 475    #[inline]
 476    pub fn update<R, C: AppContext>(
 477        &self,
 478        cx: &mut C,
 479        update: impl FnOnce(&mut T, &mut Context<T>) -> R,
 480    ) -> R {
 481        cx.update_entity(self, update)
 482    }
 483
 484    /// Updates the entity referenced by this handle with the given function.
 485    #[inline]
 486    pub fn as_mut<'a, C: AppContext>(&self, cx: &'a mut C) -> GpuiBorrow<'a, T> {
 487        cx.as_mut(self)
 488    }
 489
 490    /// Updates the entity referenced by this handle with the given function.
 491    pub fn write<C: AppContext>(&self, cx: &mut C, value: T) {
 492        self.update(cx, |entity, cx| {
 493            *entity = value;
 494            cx.notify();
 495        })
 496    }
 497
 498    /// Updates the entity referenced by this handle with the given function if
 499    /// the referenced entity still exists, within a visual context that has a window.
 500    /// Returns an error if the window has been closed.
 501    #[inline]
 502    pub fn update_in<R, C: VisualContext>(
 503        &self,
 504        cx: &mut C,
 505        update: impl FnOnce(&mut T, &mut Window, &mut Context<T>) -> R,
 506    ) -> C::Result<R> {
 507        cx.update_window_entity(self, update)
 508    }
 509}
 510
 511impl<T> Clone for Entity<T> {
 512    #[inline]
 513    fn clone(&self) -> Self {
 514        Self {
 515            any_entity: self.any_entity.clone(),
 516            entity_type: self.entity_type,
 517        }
 518    }
 519}
 520
 521impl<T> std::fmt::Debug for Entity<T> {
 522    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 523        f.debug_struct("Entity")
 524            .field("entity_id", &self.any_entity.entity_id)
 525            .field("entity_type", &type_name::<T>())
 526            .finish()
 527    }
 528}
 529
 530impl<T> Hash for Entity<T> {
 531    #[inline]
 532    fn hash<H: Hasher>(&self, state: &mut H) {
 533        self.any_entity.hash(state);
 534    }
 535}
 536
 537impl<T> PartialEq for Entity<T> {
 538    #[inline]
 539    fn eq(&self, other: &Self) -> bool {
 540        self.any_entity == other.any_entity
 541    }
 542}
 543
 544impl<T> Eq for Entity<T> {}
 545
 546impl<T> PartialEq<WeakEntity<T>> for Entity<T> {
 547    #[inline]
 548    fn eq(&self, other: &WeakEntity<T>) -> bool {
 549        self.any_entity.entity_id() == other.entity_id()
 550    }
 551}
 552
 553impl<T: 'static> Ord for Entity<T> {
 554    #[inline]
 555    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
 556        self.entity_id().cmp(&other.entity_id())
 557    }
 558}
 559
 560impl<T: 'static> PartialOrd for Entity<T> {
 561    #[inline]
 562    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
 563        Some(self.cmp(other))
 564    }
 565}
 566
/// A type erased, weak reference to an entity.
///
/// Does not contribute to the entity's strong count; use `upgrade` to obtain
/// an [`AnyEntity`].
#[derive(Clone)]
pub struct AnyWeakEntity {
    pub(crate) entity_id: EntityId,
    entity_type: TypeId,
    // Dangling (`Weak::new()`) for handles created by `new_invalid`.
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
 574
 575impl AnyWeakEntity {
 576    /// Get the entity ID associated with this weak reference.
 577    #[inline]
 578    pub fn entity_id(&self) -> EntityId {
 579        self.entity_id
 580    }
 581
 582    /// Check if this weak handle can be upgraded, or if the entity has already been dropped
 583    pub fn is_upgradable(&self) -> bool {
 584        let ref_count = self
 585            .entity_ref_counts
 586            .upgrade()
 587            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
 588            .unwrap_or(0);
 589        ref_count > 0
 590    }
 591
 592    /// Upgrade this weak entity reference to a strong reference.
 593    pub fn upgrade(&self) -> Option<AnyEntity> {
 594        let ref_counts = &self.entity_ref_counts.upgrade()?;
 595        let ref_counts = ref_counts.read();
 596        let ref_count = ref_counts.counts.get(self.entity_id)?;
 597
 598        if atomic_incr_if_not_zero(ref_count) == 0 {
 599            // entity_id is in dropped_entity_ids
 600            return None;
 601        }
 602        drop(ref_counts);
 603
 604        Some(AnyEntity {
 605            entity_id: self.entity_id,
 606            entity_type: self.entity_type,
 607            entity_map: self.entity_ref_counts.clone(),
 608            #[cfg(any(test, feature = "leak-detection"))]
 609            handle_id: self
 610                .entity_ref_counts
 611                .upgrade()
 612                .unwrap()
 613                .write()
 614                .leak_detector
 615                .handle_created(self.entity_id, None),
 616        })
 617    }
 618
 619    /// Asserts that the entity referenced by this weak handle has been fully released.
 620    ///
 621    /// # Example
 622    ///
 623    /// ```ignore
 624    /// let entity = cx.new(|_| MyEntity::new());
 625    /// let weak = entity.downgrade();
 626    /// drop(entity);
 627    ///
 628    /// // Verify the entity was released
 629    /// weak.assert_released();
 630    /// ```
 631    ///
 632    /// # Debugging Leaks
 633    ///
 634    /// If this method panics due to leaked handles, set the `LEAK_BACKTRACE` environment
 635    /// variable to see where the leaked handles were allocated:
 636    ///
 637    /// ```bash
 638    /// LEAK_BACKTRACE=1 cargo test my_test
 639    /// ```
 640    ///
 641    /// # Panics
 642    ///
 643    /// - Panics if any strong handles to the entity are still alive.
 644    /// - Panics if the entity was recently dropped but cleanup hasn't completed yet
 645    ///   (resources are retained until the end of the effect cycle).
 646    #[cfg(any(test, feature = "leak-detection"))]
 647    pub fn assert_released(&self) {
 648        self.entity_ref_counts
 649            .upgrade()
 650            .unwrap()
 651            .write()
 652            .leak_detector
 653            .assert_released(self.entity_id);
 654
 655        if self
 656            .entity_ref_counts
 657            .upgrade()
 658            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
 659            .is_some()
 660        {
 661            panic!(
 662                "entity was recently dropped but resources are retained until the end of the effect cycle."
 663            )
 664        }
 665    }
 666
 667    /// Creates a weak entity that can never be upgraded.
 668    pub fn new_invalid() -> Self {
 669        /// To hold the invariant that all ids are unique, and considering that slotmap
 670        /// increases their IDs from `0`, we can decrease ours from `u64::MAX` so these
 671        /// two will never conflict (u64 is way too large).
 672        static UNIQUE_NON_CONFLICTING_ID_GENERATOR: AtomicU64 = AtomicU64::new(u64::MAX);
 673        let entity_id = UNIQUE_NON_CONFLICTING_ID_GENERATOR.fetch_sub(1, SeqCst);
 674
 675        Self {
 676            // Safety:
 677            //   Docs say this is safe but can be unspecified if slotmap changes the representation
 678            //   after `1.0.7`, that said, providing a valid entity_id here is not necessary as long
 679            //   as we guarantee that `entity_id` is never used if `entity_ref_counts` equals
 680            //   to `Weak::new()` (that is, it's unable to upgrade), that is the invariant that
 681            //   actually needs to be hold true.
 682            //
 683            //   And there is no sane reason to read an entity slot if `entity_ref_counts` can't be
 684            //   read in the first place, so we're good!
 685            entity_id: entity_id.into(),
 686            entity_type: TypeId::of::<()>(),
 687            entity_ref_counts: Weak::new(),
 688        }
 689    }
 690}
 691
 692impl std::fmt::Debug for AnyWeakEntity {
 693    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 694        f.debug_struct(type_name::<Self>())
 695            .field("entity_id", &self.entity_id)
 696            .field("entity_type", &self.entity_type)
 697            .finish()
 698    }
 699}
 700
 701impl<T> From<WeakEntity<T>> for AnyWeakEntity {
 702    #[inline]
 703    fn from(entity: WeakEntity<T>) -> Self {
 704        entity.any_entity
 705    }
 706}
 707
 708impl Hash for AnyWeakEntity {
 709    #[inline]
 710    fn hash<H: Hasher>(&self, state: &mut H) {
 711        self.entity_id.hash(state);
 712    }
 713}
 714
 715impl PartialEq for AnyWeakEntity {
 716    #[inline]
 717    fn eq(&self, other: &Self) -> bool {
 718        self.entity_id == other.entity_id
 719    }
 720}
 721
 722impl Eq for AnyWeakEntity {}
 723
 724impl Ord for AnyWeakEntity {
 725    #[inline]
 726    fn cmp(&self, other: &Self) -> Ordering {
 727        self.entity_id.cmp(&other.entity_id)
 728    }
 729}
 730
 731impl PartialOrd for AnyWeakEntity {
 732    #[inline]
 733    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 734        Some(self.cmp(other))
 735    }
 736}
 737
/// A weak reference to an entity of the given type.
///
/// Dereferences to the underlying [`AnyWeakEntity`].
#[derive(Deref, DerefMut)]
pub struct WeakEntity<T> {
    #[deref]
    #[deref_mut]
    any_entity: AnyWeakEntity,
    // Same invariant, `Send + Sync`-neutral marker as `Entity<T>`.
    entity_type: PhantomData<fn(T) -> T>,
}
 746
 747impl<T> std::fmt::Debug for WeakEntity<T> {
 748    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 749        f.debug_struct(type_name::<Self>())
 750            .field("entity_id", &self.any_entity.entity_id)
 751            .field("entity_type", &type_name::<T>())
 752            .finish()
 753    }
 754}
 755
 756impl<T> Clone for WeakEntity<T> {
 757    fn clone(&self) -> Self {
 758        Self {
 759            any_entity: self.any_entity.clone(),
 760            entity_type: self.entity_type,
 761        }
 762    }
 763}
 764
 765impl<T: 'static> WeakEntity<T> {
 766    /// Upgrade this weak entity reference into a strong entity reference
 767    pub fn upgrade(&self) -> Option<Entity<T>> {
 768        Some(Entity {
 769            any_entity: self.any_entity.upgrade()?,
 770            entity_type: self.entity_type,
 771        })
 772    }
 773
 774    /// Updates the entity referenced by this handle with the given function if
 775    /// the referenced entity still exists. Returns an error if the entity has
 776    /// been released.
 777    pub fn update<C, R>(
 778        &self,
 779        cx: &mut C,
 780        update: impl FnOnce(&mut T, &mut Context<T>) -> R,
 781    ) -> Result<R>
 782    where
 783        C: AppContext,
 784    {
 785        let entity = self.upgrade().context("entity released")?;
 786        Ok(cx.update_entity(&entity, update))
 787    }
 788
 789    /// Updates the entity referenced by this handle with the given function if
 790    /// the referenced entity still exists, within a visual context that has a window.
 791    /// Returns an error if the entity has been released.
 792    pub fn update_in<C, R>(
 793        &self,
 794        cx: &mut C,
 795        update: impl FnOnce(&mut T, &mut Window, &mut Context<T>) -> R,
 796    ) -> Result<R>
 797    where
 798        C: VisualContext,
 799    {
 800        let window = cx.window_handle();
 801        let entity = self.upgrade().context("entity released")?;
 802
 803        window.update(cx, |_, window, cx| {
 804            entity.update(cx, |entity, cx| update(entity, window, cx))
 805        })
 806    }
 807
 808    /// Reads the entity referenced by this handle with the given function if
 809    /// the referenced entity still exists. Returns an error if the entity has
 810    /// been released.
 811    pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &App) -> R) -> Result<R>
 812    where
 813        C: AppContext,
 814    {
 815        let entity = self.upgrade().context("entity released")?;
 816        Ok(cx.read_entity(&entity, read))
 817    }
 818
 819    /// Create a new weak entity that can never be upgraded.
 820    #[inline]
 821    pub fn new_invalid() -> Self {
 822        Self {
 823            any_entity: AnyWeakEntity::new_invalid(),
 824            entity_type: PhantomData,
 825        }
 826    }
 827}
 828
 829impl<T> Hash for WeakEntity<T> {
 830    #[inline]
 831    fn hash<H: Hasher>(&self, state: &mut H) {
 832        self.any_entity.hash(state);
 833    }
 834}
 835
 836impl<T> PartialEq for WeakEntity<T> {
 837    #[inline]
 838    fn eq(&self, other: &Self) -> bool {
 839        self.any_entity == other.any_entity
 840    }
 841}
 842
 843impl<T> Eq for WeakEntity<T> {}
 844
 845impl<T> PartialEq<Entity<T>> for WeakEntity<T> {
 846    #[inline]
 847    fn eq(&self, other: &Entity<T>) -> bool {
 848        self.entity_id() == other.any_entity.entity_id()
 849    }
 850}
 851
 852impl<T: 'static> Ord for WeakEntity<T> {
 853    #[inline]
 854    fn cmp(&self, other: &Self) -> Ordering {
 855        self.entity_id().cmp(&other.entity_id())
 856    }
 857}
 858
 859impl<T: 'static> PartialOrd for WeakEntity<T> {
 860    #[inline]
 861    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 862        Some(self.cmp(other))
 863    }
 864}
 865
 866/// Controls whether backtraces are captured when entity handles are created.
 867///
 868/// Set the `LEAK_BACKTRACE` environment variable to any non-empty value to enable
 869/// backtrace capture. This helps identify where leaked handles were allocated.
 870#[cfg(any(test, feature = "leak-detection"))]
 871static LEAK_BACKTRACE: std::sync::LazyLock<bool> =
 872    std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty()));
 873
/// Unique identifier for a specific entity handle instance.
///
/// This is distinct from `EntityId` - while multiple handles can point to the same
/// entity (same `EntityId`), each handle has its own unique `HandleId`.
/// Ids are handed out sequentially by `LeakDetector::handle_created`.
#[cfg(any(test, feature = "leak-detection"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub(crate) struct HandleId {
    id: u64,
}
 883
 884/// Tracks entity handle allocations to detect leaks.
 885///
 886/// The leak detector is enabled in tests and when the `leak-detection` feature is active.
 887/// It tracks every `Entity<T>` and `AnyEntity` handle that is created and released,
 888/// allowing you to verify that all handles to an entity have been properly dropped.
 889///
 890/// # How do leaks happen?
 891///
 892/// Entities are reference-counted structures that can own other entities
 893/// allowing to form cycles. If such a strong-reference counted cycle is
 894/// created, all participating strong entities in this cycle will effectively
 895/// leak as they cannot be released anymore.
 896///
 897/// Cycles can also happen if an entity owns a task or subscription that it
 898/// itself owns a strong reference to the entity again.
 899///
 900/// # Usage
 901///
 902/// You can use `WeakEntity::assert_released` or `AnyWeakEntity::assert_released`
 903/// to verify that an entity has been fully released:
 904///
 905/// ```ignore
 906/// let entity = cx.new(|_| MyEntity::new());
 907/// let weak = entity.downgrade();
 908/// drop(entity);
 909///
 910/// // This will panic if any handles to the entity are still alive
 911/// weak.assert_released();
 912/// ```
 913///
 914/// # Debugging Leaks
 915///
 916/// When a leak is detected, the detector will panic with information about the leaked
 917/// handles. To see where the leaked handles were allocated, set the `LEAK_BACKTRACE`
 918/// environment variable:
 919///
 920/// ```bash
 921/// LEAK_BACKTRACE=1 cargo test my_test
 922/// ```
 923///
 924/// This will capture and display backtraces for each leaked handle, helping you
 925/// identify where leaked handles were created.
 926///
 927/// # How It Works
 928///
 929/// - When an entity handle is created (via `Entity::new`, `Entity::clone`, or
 930///   `WeakEntity::upgrade`), `handle_created` is called to register the handle.
 931/// - When a handle is dropped, `handle_released` removes it from tracking.
 932/// - `assert_released` verifies that no handles remain for a given entity.
#[cfg(any(test, feature = "leak-detection"))]
pub(crate) struct LeakDetector {
    /// Counter used to mint the next `HandleId`; advanced on every
    /// `handle_created` call.
    next_handle_id: u64,
    /// Live handles grouped by entity. An entity's entry is removed as soon as
    /// its last handle is released (or when `assert_released` consumes it).
    entity_handles: HashMap<EntityId, EntityLeakData>,
}
 938
/// A snapshot of the set of alive entities at a point in time.
///
/// Created by [`LeakDetector::snapshot`]. Can later be passed to
/// [`LeakDetector::assert_no_new_leaks`] to verify that no entity that was
/// untracked when the snapshot was taken still has live handles.
#[cfg(any(test, feature = "leak-detection"))]
pub struct LeakDetectorSnapshot {
    /// IDs of every entity that had at least one tracked handle when the
    /// snapshot was taken.
    entity_ids: collections::HashSet<EntityId>,
}
 948
/// Per-entity bookkeeping kept by the leak detector.
#[cfg(any(test, feature = "leak-detection"))]
struct EntityLeakData {
    /// Every live handle for the entity, paired with its allocation backtrace
    /// (`Some` only when `LEAK_BACKTRACE` was set at creation time; stored
    /// unresolved and resolved only when a leak is reported).
    handles: HashMap<HandleId, Option<backtrace::Backtrace>>,
    /// Type name used in leak reports (`<unknown>` when `handle_created` was
    /// given no name).
    type_name: &'static str,
}
 954
 955#[cfg(any(test, feature = "leak-detection"))]
 956impl LeakDetector {
 957    /// Records that a new handle has been created for the given entity.
 958    ///
 959    /// Returns a unique `HandleId` that must be passed to `handle_released` when
 960    /// the handle is dropped. If `LEAK_BACKTRACE` is set, captures a backtrace
 961    /// at the allocation site.
 962    #[track_caller]
 963    pub fn handle_created(
 964        &mut self,
 965        entity_id: EntityId,
 966        type_name: Option<&'static str>,
 967    ) -> HandleId {
 968        let id = gpui_util::post_inc(&mut self.next_handle_id);
 969        let handle_id = HandleId { id };
 970        let handles = self
 971            .entity_handles
 972            .entry(entity_id)
 973            .or_insert_with(|| EntityLeakData {
 974                handles: HashMap::default(),
 975                type_name: type_name.unwrap_or("<unknown>"),
 976            });
 977        handles.handles.insert(
 978            handle_id,
 979            LEAK_BACKTRACE.then(backtrace::Backtrace::new_unresolved),
 980        );
 981        handle_id
 982    }
 983
 984    /// Records that a handle has been released (dropped).
 985    ///
 986    /// This removes the handle from tracking. The `handle_id` should be the same
 987    /// one returned by `handle_created` when the handle was allocated.
 988    pub fn handle_released(&mut self, entity_id: EntityId, handle_id: HandleId) {
 989        if let std::collections::hash_map::Entry::Occupied(mut data) =
 990            self.entity_handles.entry(entity_id)
 991        {
 992            data.get_mut().handles.remove(&handle_id);
 993            if data.get().handles.is_empty() {
 994                data.remove();
 995            }
 996        }
 997    }
 998
 999    /// Asserts that all handles to the given entity have been released.
1000    ///
1001    /// # Panics
1002    ///
1003    /// Panics if any handles to the entity are still alive. The panic message
1004    /// includes backtraces for each leaked handle if `LEAK_BACKTRACE` is set,
1005    /// otherwise it suggests setting the environment variable to get more info.
1006    pub fn assert_released(&mut self, entity_id: EntityId) {
1007        use std::fmt::Write as _;
1008
1009        if let Some(data) = self.entity_handles.remove(&entity_id) {
1010            let mut out = String::new();
1011            for (_, backtrace) in data.handles {
1012                if let Some(mut backtrace) = backtrace {
1013                    backtrace.resolve();
1014                    let backtrace = BacktraceFormatter(backtrace);
1015                    writeln!(out, "Leaked handle:\n{:?}", backtrace).unwrap();
1016                } else {
1017                    writeln!(
1018                        out,
1019                        "Leaked handle: (export LEAK_BACKTRACE to find allocation site)"
1020                    )
1021                    .unwrap();
1022                }
1023            }
1024            panic!("Handles for {} leaked:\n{out}", data.type_name);
1025        }
1026    }
1027
1028    /// Captures a snapshot of all entity IDs that currently have alive handles.
1029    ///
1030    /// The returned [`LeakDetectorSnapshot`] can later be passed to
1031    /// [`assert_no_new_leaks`](Self::assert_no_new_leaks) to verify that no
1032    /// entities created after the snapshot are still alive.
1033    pub fn snapshot(&self) -> LeakDetectorSnapshot {
1034        LeakDetectorSnapshot {
1035            entity_ids: self.entity_handles.keys().copied().collect(),
1036        }
1037    }
1038
1039    /// Asserts that no entities created after `snapshot` still have alive handles.
1040    ///
1041    /// Entities that were already tracked at the time of the snapshot are ignored,
1042    /// even if they still have handles. Only *new* entities (those whose
1043    /// `EntityId` was not present in the snapshot) are considered leaks.
1044    ///
1045    /// # Panics
1046    ///
1047    /// Panics if any new entity handles exist. The panic message lists every
1048    /// leaked entity with its type name, and includes allocation-site backtraces
1049    /// when `LEAK_BACKTRACE` is set.
1050    pub fn assert_no_new_leaks(&self, snapshot: &LeakDetectorSnapshot) {
1051        use std::fmt::Write as _;
1052
1053        let mut out = String::new();
1054        for (entity_id, data) in &self.entity_handles {
1055            if snapshot.entity_ids.contains(entity_id) {
1056                continue;
1057            }
1058            for (_, backtrace) in &data.handles {
1059                if let Some(backtrace) = backtrace {
1060                    let mut backtrace = backtrace.clone();
1061                    backtrace.resolve();
1062                    let backtrace = BacktraceFormatter(backtrace);
1063                    writeln!(
1064                        out,
1065                        "Leaked handle for entity {} ({entity_id:?}):\n{:?}",
1066                        data.type_name, backtrace
1067                    )
1068                    .unwrap();
1069                } else {
1070                    writeln!(
1071                        out,
1072                        "Leaked handle for entity {} ({entity_id:?}): (export LEAK_BACKTRACE to find allocation site)",
1073                        data.type_name
1074                    )
1075                    .unwrap();
1076                }
1077            }
1078        }
1079
1080        if !out.is_empty() {
1081            panic!("New entity leaks detected since snapshot:\n{out}");
1082        }
1083    }
1084}
1085
#[cfg(any(test, feature = "leak-detection"))]
impl Drop for LeakDetector {
    /// Panics with a report of every handle that is still tracked when the
    /// detector goes away.
    ///
    /// Nothing is reported when no handles remain, or when the thread is already
    /// panicking (to avoid a panic while unwinding).
    fn drop(&mut self) {
        use std::fmt::Write as _;

        let nothing_leaked = self.entity_handles.is_empty();
        if nothing_leaked || std::thread::panicking() {
            return;
        }

        let mut report = String::new();
        for (entity_id, leak) in self.entity_handles.drain() {
            let type_name = leak.type_name;
            for captured in leak.handles.into_values() {
                match captured {
                    Some(mut trace) => {
                        trace.resolve();
                        writeln!(
                            report,
                            "Leaked handle for entity {type_name} ({entity_id:?}):\n{:?}",
                            BacktraceFormatter(trace)
                        )
                        .unwrap();
                    }
                    None => writeln!(
                        report,
                        "Leaked handle for entity {type_name} ({entity_id:?}): (export LEAK_BACKTRACE to find allocation site)"
                    )
                    .unwrap(),
                }
            }
        }
        panic!("Exited with leaked handles:\n{report}");
    }
}
1120
/// Wrapper whose `Debug` impl prints a leaked handle's allocation backtrace,
/// trimming frames that are not useful for locating the leak.
#[cfg(any(test, feature = "leak-detection"))]
struct BacktraceFormatter(backtrace::Backtrace);
1123
1124#[cfg(any(test, feature = "leak-detection"))]
1125impl fmt::Debug for BacktraceFormatter {
1126    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
1127        use backtrace::{BacktraceFmt, BytesOrWideString, PrintFmt};
1128
1129        let style = if fmt.alternate() {
1130            PrintFmt::Full
1131        } else {
1132            PrintFmt::Short
1133        };
1134
1135        // When printing paths we try to strip the cwd if it exists, otherwise
1136        // we just print the path as-is. Note that we also only do this for the
1137        // short format, because if it's full we presumably want to print
1138        // everything.
1139        let cwd = std::env::current_dir();
1140        let mut print_path = move |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
1141            let path = path.into_path_buf();
1142            if style != PrintFmt::Full {
1143                if let Ok(cwd) = &cwd {
1144                    if let Ok(suffix) = path.strip_prefix(cwd) {
1145                        return fmt::Display::fmt(&suffix.display(), fmt);
1146                    }
1147                }
1148            }
1149            fmt::Display::fmt(&path.display(), fmt)
1150        };
1151
1152        let mut f = BacktraceFmt::new(fmt, style, &mut print_path);
1153        f.add_context()?;
1154        let mut strip = true;
1155        for frame in self.0.frames() {
1156            if let [symbol, ..] = frame.symbols()
1157                && let Some(name) = symbol.name()
1158                && let Some(filename) = name.as_str()
1159            {
1160                match filename {
1161                    "test::run_test_in_process"
1162                    | "scheduler::executor::spawn_local_with_source_location::impl$1::poll<core::pin::Pin<alloc::boxed::Box<dyn$<core::future::future::Future<assoc$<Output,enum2$<core::result::Result<workspace::OpenResult,anyhow::Error> > > > >,alloc::alloc::Global> > >" => {
1163                        strip = true
1164                    }
1165                    "gpui::app::entity_map::LeakDetector::handle_created" => {
1166                        strip = false;
1167                        continue;
1168                    }
1169                    "zed::main" => {
1170                        strip = true;
1171                        f.frame().backtrace_frame(frame)?;
1172                    }
1173                    _ => {}
1174                }
1175            }
1176            if strip {
1177                continue;
1178            }
1179            f.frame().backtrace_frame(frame)?;
1180        }
1181        f.finish()?;
1182        Ok(())
1183    }
1184}
1185
#[cfg(test)]
mod test {
    use crate::EntityMap;

    /// Minimal entity payload; `i` lets a test tell instances apart once they
    /// come back out of `take_dropped`.
    struct TestEntity {
        pub i: i32,
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Slots must not be recycled until `take_dropped` has run: both
        // entities, dropped immediately after insertion, must come back.
        let mut map = EntityMap::new();

        for value in [1, 2] {
            let reserved = map.reserve::<TestEntity>();
            map.insert(reserved, TestEntity { i: value });
        }

        let dropped = map.take_dropped();
        assert_eq!(dropped.len(), 2);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // Once the last strong handle is gone, a weak handle must not upgrade,
        // even though the entity is still waiting for `take_dropped`.
        let mut map = EntityMap::new();

        let reserved = map.reserve::<TestEntity>();
        let weak = {
            let strong = map.insert(reserved, TestEntity { i: 1 });
            strong.downgrade()
        };

        assert_eq!(weak.upgrade(), None);

        let dropped = map.take_dropped();
        assert_eq!(dropped.len(), 1);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1]);
    }

    #[test]
    fn test_leak_detector_snapshot_no_leaks() {
        let mut map = EntityMap::new();

        let reserved = map.reserve::<TestEntity>();
        let survivor = map.insert(reserved, TestEntity { i: 1 });

        let snapshot = map.leak_detector_snapshot();

        // Created and released after the snapshot: not a leak.
        let reserved = map.reserve::<TestEntity>();
        drop(map.insert(reserved, TestEntity { i: 2 }));

        // `survivor` predates the snapshot, so it is ignored while still alive.
        map.assert_no_new_leaks(&snapshot);

        drop(survivor);
    }

    #[test]
    #[should_panic(expected = "New entity leaks detected since snapshot")]
    fn test_leak_detector_snapshot_detects_new_leak() {
        let mut map = EntityMap::new();

        let reserved = map.reserve::<TestEntity>();
        let survivor = map.insert(reserved, TestEntity { i: 1 });

        let snapshot = map.leak_detector_snapshot();

        let reserved = map.reserve::<TestEntity>();
        let leaked = map.insert(reserved, TestEntity { i: 2 });

        // `leaked` is still alive, so this should panic.
        map.assert_no_new_leaks(&snapshot);

        drop(survivor);
        drop(leaked);
    }
}