entity_map.rs

  1use crate::{private::Sealed, AppContext, Context, Entity, ModelContext};
  2use anyhow::{anyhow, Result};
  3use collections::HashMap;
  4use derive_more::{Deref, DerefMut};
  5use lazy_static::lazy_static;
  6use parking_lot::{RwLock, RwLockUpgradableReadGuard};
  7use slotmap::{SecondaryMap, SlotMap};
  8use std::{
  9    any::{type_name, Any, TypeId},
 10    fmt::{self, Display},
 11    hash::{Hash, Hasher},
 12    marker::PhantomData,
 13    mem,
 14    sync::{
 15        atomic::{AtomicUsize, Ordering::SeqCst},
 16        Arc, Weak,
 17    },
 18    thread::panicking,
 19};
 20
// Slotmap key identifying an entity. The same key indexes both the entity
// storage (`EntityMap::entities`) and the reference counts
// (`EntityRefCounts::counts`).
slotmap::new_key_type! { pub struct EntityId; }
 22
 23impl EntityId {
 24    pub fn as_u64(self) -> u64 {
 25        self.0.as_ffi()
 26    }
 27}
 28
 29impl Display for EntityId {
 30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 31        write!(f, "{}", self.as_u64())
 32    }
 33}
 34
/// Owns every entity's state, keyed by `EntityId`.
///
/// Entity values are stored type-erased in `entities`. The reference counts live
/// separately behind a shared lock, and model handles hold a `Weak` pointer to
/// them so they can adjust counts on clone/drop.
pub(crate) struct EntityMap {
    // Type-erased entity state. An entry is temporarily absent while the entity
    // is leased out via `lease` (it is put back by `end_lease`).
    entities: SecondaryMap<EntityId, Box<dyn Any>>,
    // Shared with every handle (through a `Weak`); see `EntityRefCounts`.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
 39
/// Reference-count bookkeeping shared between `EntityMap` and all model handles.
struct EntityRefCounts {
    // Strong-handle count per entity. Entity ids are allocated here (by `reserve`).
    counts: SlotMap<EntityId, AtomicUsize>,
    // Ids whose count reached zero, waiting to be reclaimed by `take_dropped`.
    dropped_entity_ids: Vec<EntityId>,
    // Tracks individual handles so tests can report leaked handles.
    #[cfg(any(test, feature = "test-support"))]
    leak_detector: LeakDetector,
}
 46
 47impl EntityMap {
 48    pub fn new() -> Self {
 49        Self {
 50            entities: SecondaryMap::new(),
 51            ref_counts: Arc::new(RwLock::new(EntityRefCounts {
 52                counts: SlotMap::with_key(),
 53                dropped_entity_ids: Vec::new(),
 54                #[cfg(any(test, feature = "test-support"))]
 55                leak_detector: LeakDetector {
 56                    next_handle_id: 0,
 57                    entity_handles: HashMap::default(),
 58                },
 59            })),
 60        }
 61    }
 62
 63    /// Reserve a slot for an entity, which you can subsequently use with `insert`.
 64    pub fn reserve<T: 'static>(&self) -> Slot<T> {
 65        let id = self.ref_counts.write().counts.insert(1.into());
 66        Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
 67    }
 68
 69    /// Insert an entity into a slot obtained by calling `reserve`.
 70    pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
 71    where
 72        T: 'static,
 73    {
 74        let model = slot.0;
 75        self.entities.insert(model.entity_id, Box::new(entity));
 76        model
 77    }
 78
 79    /// Move an entity to the stack.
 80    #[track_caller]
 81    pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
 82        self.assert_valid_context(model);
 83        let entity = Some(self.entities.remove(model.entity_id).unwrap_or_else(|| {
 84            panic!(
 85                "Circular entity lease of {}. Is it already being updated?",
 86                std::any::type_name::<T>()
 87            )
 88        }));
 89        Lease {
 90            model,
 91            entity,
 92            entity_type: PhantomData,
 93        }
 94    }
 95
 96    /// Return an entity after moving it to the stack.
 97    pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
 98        self.entities
 99            .insert(lease.model.entity_id, lease.entity.take().unwrap());
100    }
101
102    pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
103        self.assert_valid_context(model);
104        self.entities[model.entity_id].downcast_ref().unwrap()
105    }
106
107    fn assert_valid_context(&self, model: &AnyModel) {
108        debug_assert!(
109            Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
110            "used a model with the wrong context"
111        );
112    }
113
114    pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
115        let mut ref_counts = self.ref_counts.write();
116        let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
117
118        dropped_entity_ids
119            .into_iter()
120            .map(|entity_id| {
121                let count = ref_counts.counts.remove(entity_id).unwrap();
122                debug_assert_eq!(
123                    count.load(SeqCst),
124                    0,
125                    "dropped an entity that was referenced"
126                );
127                (entity_id, self.entities.remove(entity_id).unwrap())
128            })
129            .collect()
130    }
131}
132
/// An entity temporarily moved out of the `EntityMap` by `EntityMap::lease`.
///
/// Derefs to the entity. It must be handed back with `EntityMap::end_lease`;
/// dropping it while it still holds the entity panics (unless the thread is
/// already panicking).
pub struct Lease<'a, T> {
    // `Some` while leased; taken by `end_lease` when the entity is put back.
    entity: Option<Box<dyn Any>>,
    pub model: &'a Model<T>,
    entity_type: PhantomData<T>,
}
138
139impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
140    type Target = T;
141
142    fn deref(&self) -> &Self::Target {
143        self.entity.as_ref().unwrap().downcast_ref().unwrap()
144    }
145}
146
147impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
148    fn deref_mut(&mut self) -> &mut Self::Target {
149        self.entity.as_mut().unwrap().downcast_mut().unwrap()
150    }
151}
152
153impl<'a, T> Drop for Lease<'a, T> {
154    fn drop(&mut self) {
155        if self.entity.is_some() && !panicking() {
156            panic!("Leases must be ended with EntityMap::end_lease")
157        }
158    }
159}
160
/// A reserved entity id whose entity has not been inserted yet.
///
/// Obtained from `EntityMap::reserve` and consumed by `EntityMap::insert`.
/// Derefs to the underlying `Model<T>` (e.g. to read its entity id).
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Model<T>);
163
/// A strong, type-erased handle to an entity.
///
/// Each live `AnyModel` accounts for one count in `EntityRefCounts::counts`.
/// Clones and weak upgrades increment it; drops decrement it.
pub struct AnyModel {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the concrete entity type; checked by `downcast`.
    pub(crate) entity_type: TypeId,
    // Weak pointer to the shared ref-count table; if the table is gone,
    // clone/drop skip the count update.
    entity_map: Weak<RwLock<EntityRefCounts>>,
    // Identifies this particular handle to the leak detector.
    #[cfg(any(test, feature = "test-support"))]
    handle_id: HandleId,
}
171
172impl AnyModel {
173    fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
174        Self {
175            entity_id: id,
176            entity_type,
177            entity_map: entity_map.clone(),
178            #[cfg(any(test, feature = "test-support"))]
179            handle_id: entity_map
180                .upgrade()
181                .unwrap()
182                .write()
183                .leak_detector
184                .handle_created(id),
185        }
186    }
187
188    pub fn entity_id(&self) -> EntityId {
189        self.entity_id
190    }
191
192    pub fn entity_type(&self) -> TypeId {
193        self.entity_type
194    }
195
196    pub fn downgrade(&self) -> AnyWeakModel {
197        AnyWeakModel {
198            entity_id: self.entity_id,
199            entity_type: self.entity_type,
200            entity_ref_counts: self.entity_map.clone(),
201        }
202    }
203
204    pub fn downcast<T: 'static>(self) -> Result<Model<T>, AnyModel> {
205        if TypeId::of::<T>() == self.entity_type {
206            Ok(Model {
207                any_model: self,
208                entity_type: PhantomData,
209            })
210        } else {
211            Err(self)
212        }
213    }
214}
215
216impl Clone for AnyModel {
217    fn clone(&self) -> Self {
218        if let Some(entity_map) = self.entity_map.upgrade() {
219            let entity_map = entity_map.read();
220            let count = entity_map
221                .counts
222                .get(self.entity_id)
223                .expect("detected over-release of a model");
224            let prev_count = count.fetch_add(1, SeqCst);
225            assert_ne!(prev_count, 0, "Detected over-release of a model.");
226        }
227
228        let this = Self {
229            entity_id: self.entity_id,
230            entity_type: self.entity_type,
231            entity_map: self.entity_map.clone(),
232            #[cfg(any(test, feature = "test-support"))]
233            handle_id: self
234                .entity_map
235                .upgrade()
236                .unwrap()
237                .write()
238                .leak_detector
239                .handle_created(self.entity_id),
240        };
241        this
242    }
243}
244
impl Drop for AnyModel {
    /// Releases this strong handle. It decrements the shared count and, if this
    /// was the last handle, queues the entity id in `dropped_entity_ids`. The
    /// entity itself is reclaimed later by `EntityMap::take_dropped`.
    fn drop(&mut self) {
        // Nothing to release if the ref-count table has already been freed.
        if let Some(entity_map) = self.entity_map.upgrade() {
            // Upgradable read: coexists with plain readers, and can be promoted to
            // a write lock below if this turns out to be the last handle.
            let entity_map = entity_map.upgradable_read();
            let count = entity_map
                .counts
                .get(self.entity_id)
                .expect("detected over-release of a handle.");
            let prev_count = count.fetch_sub(1, SeqCst);
            assert_ne!(prev_count, 0, "Detected over-release of a model.");
            if prev_count == 1 {
                // We were the last reference to this entity, so we can remove it.
                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
                entity_map.dropped_entity_ids.push(self.entity_id);
            }
        }

        // Runs after the block above has released its lock guard, so taking the
        // write lock here cannot self-deadlock.
        #[cfg(any(test, feature = "test-support"))]
        if let Some(entity_map) = self.entity_map.upgrade() {
            entity_map
                .write()
                .leak_detector
                .handle_dropped(self.entity_id, self.handle_id)
        }
    }
}
271
272impl<T> From<Model<T>> for AnyModel {
273    fn from(model: Model<T>) -> Self {
274        model.any_model
275    }
276}
277
278impl Hash for AnyModel {
279    fn hash<H: Hasher>(&self, state: &mut H) {
280        self.entity_id.hash(state);
281    }
282}
283
284impl PartialEq for AnyModel {
285    fn eq(&self, other: &Self) -> bool {
286        self.entity_id == other.entity_id
287    }
288}
289
290impl Eq for AnyModel {}
291
292impl std::fmt::Debug for AnyModel {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        f.debug_struct("AnyModel")
295            .field("entity_id", &self.entity_id.as_u64())
296            .finish()
297    }
298}
299
/// A strong, typed handle to an entity of type `T`.
///
/// Wraps (and derefs to) an `AnyModel`; `T` is tracked only at the type level
/// through `PhantomData`.
#[derive(Deref, DerefMut)]
pub struct Model<T> {
    #[deref]
    #[deref_mut]
    pub(crate) any_model: AnyModel,
    pub(crate) entity_type: PhantomData<T>,
}
307
// SAFETY: `Model<T>` never stores a `T`. It holds an `AnyModel` (an id, a
// `TypeId`, a `Weak` to lock-protected ref counts) plus `PhantomData<T>`, so
// moving or sharing the handle does not move or share entity state. The entity
// is only reached through a context (`read`/`update`).
// NOTE(review): confirm no API yields `&T` from a `Model<T>` without a context.
unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Model<T> {}
// `Sealed` comes from `crate::private`; presumably it restricts which types can
// implement crate traits such as `Entity` — confirm.
impl<T> Sealed for Model<T> {}
311
312impl<T: 'static> Entity<T> for Model<T> {
313    type Weak = WeakModel<T>;
314
315    fn entity_id(&self) -> EntityId {
316        self.any_model.entity_id
317    }
318
319    fn downgrade(&self) -> Self::Weak {
320        WeakModel {
321            any_model: self.any_model.downgrade(),
322            entity_type: self.entity_type,
323        }
324    }
325
326    fn upgrade_from(weak: &Self::Weak) -> Option<Self>
327    where
328        Self: Sized,
329    {
330        Some(Model {
331            any_model: weak.any_model.upgrade()?,
332            entity_type: weak.entity_type,
333        })
334    }
335}
336
337impl<T: 'static> Model<T> {
338    fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
339    where
340        T: 'static,
341    {
342        Self {
343            any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
344            entity_type: PhantomData,
345        }
346    }
347
348    /// Downgrade the this to a weak model reference
349    pub fn downgrade(&self) -> WeakModel<T> {
350        // Delegate to the trait implementation to keep behavior in one place.
351        // This method was included to improve method resolution in the presence of
352        // the Model's deref
353        Entity::downgrade(self)
354    }
355
356    /// Convert this into a dynamically typed model.
357    pub fn into_any(self) -> AnyModel {
358        self.any_model
359    }
360
361    pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
362        cx.entities.read(self)
363    }
364
365    pub fn read_with<R, C: Context>(
366        &self,
367        cx: &C,
368        f: impl FnOnce(&T, &AppContext) -> R,
369    ) -> C::Result<R> {
370        cx.read_model(self, f)
371    }
372
373    /// Update the entity referenced by this model with the given function.
374    ///
375    /// The update function receives a context appropriate for its environment.
376    /// When updating in an `AppContext`, it receives a `ModelContext`.
377    /// When updating an a `WindowContext`, it receives a `ViewContext`.
378    pub fn update<C, R>(
379        &self,
380        cx: &mut C,
381        update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
382    ) -> C::Result<R>
383    where
384        C: Context,
385    {
386        cx.update_model(self, update)
387    }
388}
389
390impl<T> Clone for Model<T> {
391    fn clone(&self) -> Self {
392        Self {
393            any_model: self.any_model.clone(),
394            entity_type: self.entity_type,
395        }
396    }
397}
398
399impl<T> std::fmt::Debug for Model<T> {
400    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401        write!(
402            f,
403            "Model {{ entity_id: {:?}, entity_type: {:?} }}",
404            self.any_model.entity_id,
405            type_name::<T>()
406        )
407    }
408}
409
410impl<T> Hash for Model<T> {
411    fn hash<H: Hasher>(&self, state: &mut H) {
412        self.any_model.hash(state);
413    }
414}
415
416impl<T> PartialEq for Model<T> {
417    fn eq(&self, other: &Self) -> bool {
418        self.any_model == other.any_model
419    }
420}
421
422impl<T> Eq for Model<T> {}
423
424impl<T> PartialEq<WeakModel<T>> for Model<T> {
425    fn eq(&self, other: &WeakModel<T>) -> bool {
426        self.any_model.entity_id() == other.entity_id()
427    }
428}
429
/// A weak, type-erased reference to an entity.
///
/// It does not contribute to the entity's reference count; use `upgrade` to
/// obtain a strong `AnyModel`.
#[derive(Clone)]
pub struct AnyWeakModel {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the concrete entity type, carried over to upgraded handles.
    entity_type: TypeId,
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
436
437impl AnyWeakModel {
438    pub fn entity_id(&self) -> EntityId {
439        self.entity_id
440    }
441
442    pub fn is_upgradable(&self) -> bool {
443        let ref_count = self
444            .entity_ref_counts
445            .upgrade()
446            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
447            .unwrap_or(0);
448        ref_count > 0
449    }
450
451    pub fn upgrade(&self) -> Option<AnyModel> {
452        let ref_counts = &self.entity_ref_counts.upgrade()?;
453        let ref_counts = ref_counts.read();
454        let ref_count = ref_counts.counts.get(self.entity_id)?;
455
456        // entity_id is in dropped_entity_ids
457        if ref_count.load(SeqCst) == 0 {
458            return None;
459        }
460        ref_count.fetch_add(1, SeqCst);
461        drop(ref_counts);
462
463        Some(AnyModel {
464            entity_id: self.entity_id,
465            entity_type: self.entity_type,
466            entity_map: self.entity_ref_counts.clone(),
467            #[cfg(any(test, feature = "test-support"))]
468            handle_id: self
469                .entity_ref_counts
470                .upgrade()
471                .unwrap()
472                .write()
473                .leak_detector
474                .handle_created(self.entity_id),
475        })
476    }
477
478    #[cfg(any(test, feature = "test-support"))]
479    pub fn assert_dropped(&self) {
480        self.entity_ref_counts
481            .upgrade()
482            .unwrap()
483            .write()
484            .leak_detector
485            .assert_dropped(self.entity_id);
486
487        if self
488            .entity_ref_counts
489            .upgrade()
490            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
491            .is_some()
492        {
493            panic!(
494                "entity was recently dropped but resources are retained until the end of the effect cycle."
495            )
496        }
497    }
498}
499
500impl<T> From<WeakModel<T>> for AnyWeakModel {
501    fn from(model: WeakModel<T>) -> Self {
502        model.any_model
503    }
504}
505
506impl Hash for AnyWeakModel {
507    fn hash<H: Hasher>(&self, state: &mut H) {
508        self.entity_id.hash(state);
509    }
510}
511
512impl PartialEq for AnyWeakModel {
513    fn eq(&self, other: &Self) -> bool {
514        self.entity_id == other.entity_id
515    }
516}
517
518impl Eq for AnyWeakModel {}
519
/// A weak, typed reference to an entity of type `T`.
///
/// Wraps (and derefs to) an `AnyWeakModel`. Upgrade with `upgrade` to get a
/// `Model<T>`.
#[derive(Deref, DerefMut)]
pub struct WeakModel<T> {
    #[deref]
    #[deref_mut]
    any_model: AnyWeakModel,
    entity_type: PhantomData<T>,
}
527
// SAFETY: like `Model<T>`, a `WeakModel<T>` holds no `T`. It holds only an
// `AnyWeakModel` and `PhantomData<T>`, so sending or sharing it never moves or
// shares entity state.
// NOTE(review): confirm this matches the reasoning behind `Model<T>`'s impls.
unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakModel<T> {}
530
531impl<T> Clone for WeakModel<T> {
532    fn clone(&self) -> Self {
533        Self {
534            any_model: self.any_model.clone(),
535            entity_type: self.entity_type,
536        }
537    }
538}
539
540impl<T: 'static> WeakModel<T> {
541    /// Upgrade this weak model reference into a strong model reference
542    pub fn upgrade(&self) -> Option<Model<T>> {
543        // Delegate to the trait implementation to keep behavior in one place.
544        Model::upgrade_from(self)
545    }
546
547    /// Update the entity referenced by this model with the given function if
548    /// the referenced entity still exists. Returns an error if the entity has
549    /// been released.
550    pub fn update<C, R>(
551        &self,
552        cx: &mut C,
553        update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
554    ) -> Result<R>
555    where
556        C: Context,
557        Result<C::Result<R>>: crate::Flatten<R>,
558    {
559        crate::Flatten::flatten(
560            self.upgrade()
561                .ok_or_else(|| anyhow!("entity release"))
562                .map(|this| cx.update_model(&this, update)),
563        )
564    }
565
566    /// Reads the entity referenced by this model with the given function if
567    /// the referenced entity still exists. Returns an error if the entity has
568    /// been released.
569    pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
570    where
571        C: Context,
572        Result<C::Result<R>>: crate::Flatten<R>,
573    {
574        crate::Flatten::flatten(
575            self.upgrade()
576                .ok_or_else(|| anyhow!("entity release"))
577                .map(|this| cx.read_model(&this, read)),
578        )
579    }
580}
581
582impl<T> Hash for WeakModel<T> {
583    fn hash<H: Hasher>(&self, state: &mut H) {
584        self.any_model.hash(state);
585    }
586}
587
588impl<T> PartialEq for WeakModel<T> {
589    fn eq(&self, other: &Self) -> bool {
590        self.any_model == other.any_model
591    }
592}
593
594impl<T> Eq for WeakModel<T> {}
595
596impl<T> PartialEq<Model<T>> for WeakModel<T> {
597    fn eq(&self, other: &Model<T>) -> bool {
598        self.entity_id() == other.any_model.entity_id()
599    }
600}
601
#[cfg(any(test, feature = "test-support"))]
lazy_static! {
    /// `true` when the `LEAK_BACKTRACE` environment variable is set to a
    /// non-empty value; evaluated lazily on first access.
    static ref LEAK_BACKTRACE: bool = match std::env::var("LEAK_BACKTRACE") {
        Ok(value) => !value.is_empty(),
        Err(_) => false,
    };
}
607
/// Identifies a single model handle (not the entity) for leak detection.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct HandleId {
    id: u64, // id of the handle itself, not the pointed at object
}
613
/// Tracks every live model handle per entity so tests can report leaks.
#[cfg(any(test, feature = "test-support"))]
pub struct LeakDetector {
    // Source of unique `HandleId`s; incremented per created handle.
    next_handle_id: u64,
    // Live handles per entity, each with an optional creation backtrace
    // (captured only when `LEAK_BACKTRACE` is enabled).
    entity_handles: HashMap<EntityId, HashMap<HandleId, Option<backtrace::Backtrace>>>,
}
619
#[cfg(any(test, feature = "test-support"))]
impl LeakDetector {
    /// Registers a new live handle for `entity_id` and returns its unique id.
    ///
    /// When `LEAK_BACKTRACE` is enabled, an unresolved backtrace of the creation
    /// site is stored so a leak can later be traced back to where it was made.
    #[track_caller]
    pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId {
        let id = util::post_inc(&mut self.next_handle_id);
        let handle_id = HandleId { id };
        let handles = self.entity_handles.entry(entity_id).or_default();
        handles.insert(
            handle_id,
            LEAK_BACKTRACE.then(backtrace::Backtrace::new_unresolved),
        );
        handle_id
    }

    /// Forgets a handle that has been dropped.
    ///
    /// When an entity's last tracked handle goes away, its (now empty) entry is
    /// removed. This keeps the detector from accumulating a map for every entity
    /// ever created.
    pub fn handle_dropped(&mut self, entity_id: EntityId, handle_id: HandleId) {
        if let Some(handles) = self.entity_handles.get_mut(&entity_id) {
            handles.remove(&handle_id);
            if handles.is_empty() {
                self.entity_handles.remove(&entity_id);
            }
        }
    }

    /// Panics if any handle to `entity_id` is still alive.
    ///
    /// Before panicking, it prints each leaked handle's creation backtrace, or a
    /// hint to set `LEAK_BACKTRACE` when none was captured.
    pub fn assert_dropped(&mut self, entity_id: EntityId) {
        let handles = match self.entity_handles.get_mut(&entity_id) {
            Some(handles) if !handles.is_empty() => handles,
            _ => return,
        };

        for backtrace in handles.values_mut() {
            if let Some(mut backtrace) = backtrace.take() {
                backtrace.resolve();
                eprintln!("Leaked handle: {:#?}", backtrace);
            } else {
                eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site");
            }
        }
        panic!(
            "{} leaked handle(s) to entity {}",
            handles.len(),
            entity_id
        );
    }
}
654
#[cfg(test)]
mod test {
    use crate::EntityMap;

    /// Minimal entity payload; `i` lets the tests tell entities apart.
    struct TestEntity {
        pub i: i32,
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Slots must not be re-used before `take_dropped` runs. Both entities
        // below are released immediately (the returned models are discarded),
        // yet each must come back with its own value.
        let mut entity_map = EntityMap::new();

        for i in [1, 2] {
            let slot = entity_map.reserve::<TestEntity>();
            entity_map.insert(slot, TestEntity { i });
        }

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // A weak handle must not upgrade once the last strong handle is gone,
        // even though the entity is only reclaimed later by `take_dropped`.
        let mut entity_map = EntityMap::new();

        let weak = {
            let slot = entity_map.reserve::<TestEntity>();
            let handle = entity_map.insert(slot, TestEntity { i: 1 });
            handle.downgrade()
        };

        assert_eq!(weak.upgrade(), None);

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1]);
    }
}