entity_map.rs

  1use crate::{seal::Sealed, AppContext, Context, Entity, ModelContext};
  2use anyhow::{anyhow, Result};
  3use derive_more::{Deref, DerefMut};
  4use parking_lot::{RwLock, RwLockUpgradableReadGuard};
  5use slotmap::{SecondaryMap, SlotMap};
  6use std::{
  7    any::{type_name, Any, TypeId},
  8    fmt::{self, Display},
  9    hash::{Hash, Hasher},
 10    marker::PhantomData,
 11    mem,
 12    sync::{
 13        atomic::{AtomicUsize, Ordering::SeqCst},
 14        Arc, Weak,
 15    },
 16    thread::panicking,
 17};
 18
 19#[cfg(any(test, feature = "test-support"))]
 20use collections::HashMap;
 21
 22slotmap::new_key_type! { pub struct EntityId; }
 23
 24impl EntityId {
 25    pub fn as_u64(self) -> u64 {
 26        self.0.as_ffi()
 27    }
 28}
 29
 30impl Display for EntityId {
 31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 32        write!(f, "{}", self.as_u64())
 33    }
 34}
 35
 36pub(crate) struct EntityMap {
 37    entities: SecondaryMap<EntityId, Box<dyn Any>>,
 38    ref_counts: Arc<RwLock<EntityRefCounts>>,
 39}
 40
 41struct EntityRefCounts {
 42    counts: SlotMap<EntityId, AtomicUsize>,
 43    dropped_entity_ids: Vec<EntityId>,
 44    #[cfg(any(test, feature = "test-support"))]
 45    leak_detector: LeakDetector,
 46}
 47
 48impl EntityMap {
 49    pub fn new() -> Self {
 50        Self {
 51            entities: SecondaryMap::new(),
 52            ref_counts: Arc::new(RwLock::new(EntityRefCounts {
 53                counts: SlotMap::with_key(),
 54                dropped_entity_ids: Vec::new(),
 55                #[cfg(any(test, feature = "test-support"))]
 56                leak_detector: LeakDetector {
 57                    next_handle_id: 0,
 58                    entity_handles: HashMap::default(),
 59                },
 60            })),
 61        }
 62    }
 63
 64    /// Reserve a slot for an entity, which you can subsequently use with `insert`.
 65    pub fn reserve<T: 'static>(&self) -> Slot<T> {
 66        let id = self.ref_counts.write().counts.insert(1.into());
 67        Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
 68    }
 69
 70    /// Insert an entity into a slot obtained by calling `reserve`.
 71    pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
 72    where
 73        T: 'static,
 74    {
 75        let model = slot.0;
 76        self.entities.insert(model.entity_id, Box::new(entity));
 77        model
 78    }
 79
 80    /// Move an entity to the stack.
 81    #[track_caller]
 82    pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
 83        self.assert_valid_context(model);
 84        let entity = Some(self.entities.remove(model.entity_id).unwrap_or_else(|| {
 85            panic!(
 86                "Circular entity lease of {}. Is it already being updated?",
 87                std::any::type_name::<T>()
 88            )
 89        }));
 90        Lease {
 91            model,
 92            entity,
 93            entity_type: PhantomData,
 94        }
 95    }
 96
 97    /// Return an entity after moving it to the stack.
 98    pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
 99        self.entities
100            .insert(lease.model.entity_id, lease.entity.take().unwrap());
101    }
102
103    pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
104        self.assert_valid_context(model);
105        self.entities[model.entity_id].downcast_ref().unwrap()
106    }
107
108    fn assert_valid_context(&self, model: &AnyModel) {
109        debug_assert!(
110            Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
111            "used a model with the wrong context"
112        );
113    }
114
115    pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
116        let mut ref_counts = self.ref_counts.write();
117        let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
118
119        dropped_entity_ids
120            .into_iter()
121            .map(|entity_id| {
122                let count = ref_counts.counts.remove(entity_id).unwrap();
123                debug_assert_eq!(
124                    count.load(SeqCst),
125                    0,
126                    "dropped an entity that was referenced"
127                );
128                (entity_id, self.entities.remove(entity_id).unwrap())
129            })
130            .collect()
131    }
132}
133
134pub struct Lease<'a, T> {
135    entity: Option<Box<dyn Any>>,
136    pub model: &'a Model<T>,
137    entity_type: PhantomData<T>,
138}
139
140impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
141    type Target = T;
142
143    fn deref(&self) -> &Self::Target {
144        self.entity.as_ref().unwrap().downcast_ref().unwrap()
145    }
146}
147
148impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
149    fn deref_mut(&mut self) -> &mut Self::Target {
150        self.entity.as_mut().unwrap().downcast_mut().unwrap()
151    }
152}
153
154impl<'a, T> Drop for Lease<'a, T> {
155    fn drop(&mut self) {
156        if self.entity.is_some() && !panicking() {
157            panic!("Leases must be ended with EntityMap::end_lease")
158        }
159    }
160}
161
162#[derive(Deref, DerefMut)]
163pub struct Slot<T>(Model<T>);
164
165pub struct AnyModel {
166    pub(crate) entity_id: EntityId,
167    pub(crate) entity_type: TypeId,
168    entity_map: Weak<RwLock<EntityRefCounts>>,
169    #[cfg(any(test, feature = "test-support"))]
170    handle_id: HandleId,
171}
172
173impl AnyModel {
174    fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
175        Self {
176            entity_id: id,
177            entity_type,
178            entity_map: entity_map.clone(),
179            #[cfg(any(test, feature = "test-support"))]
180            handle_id: entity_map
181                .upgrade()
182                .unwrap()
183                .write()
184                .leak_detector
185                .handle_created(id),
186        }
187    }
188
189    pub fn entity_id(&self) -> EntityId {
190        self.entity_id
191    }
192
193    pub fn entity_type(&self) -> TypeId {
194        self.entity_type
195    }
196
197    pub fn downgrade(&self) -> AnyWeakModel {
198        AnyWeakModel {
199            entity_id: self.entity_id,
200            entity_type: self.entity_type,
201            entity_ref_counts: self.entity_map.clone(),
202        }
203    }
204
205    pub fn downcast<T: 'static>(self) -> Result<Model<T>, AnyModel> {
206        if TypeId::of::<T>() == self.entity_type {
207            Ok(Model {
208                any_model: self,
209                entity_type: PhantomData,
210            })
211        } else {
212            Err(self)
213        }
214    }
215}
216
217impl Clone for AnyModel {
218    fn clone(&self) -> Self {
219        if let Some(entity_map) = self.entity_map.upgrade() {
220            let entity_map = entity_map.read();
221            let count = entity_map
222                .counts
223                .get(self.entity_id)
224                .expect("detected over-release of a model");
225            let prev_count = count.fetch_add(1, SeqCst);
226            assert_ne!(prev_count, 0, "Detected over-release of a model.");
227        }
228
229        let this = Self {
230            entity_id: self.entity_id,
231            entity_type: self.entity_type,
232            entity_map: self.entity_map.clone(),
233            #[cfg(any(test, feature = "test-support"))]
234            handle_id: self
235                .entity_map
236                .upgrade()
237                .unwrap()
238                .write()
239                .leak_detector
240                .handle_created(self.entity_id),
241        };
242        this
243    }
244}
245
246impl Drop for AnyModel {
247    fn drop(&mut self) {
248        if let Some(entity_map) = self.entity_map.upgrade() {
249            let entity_map = entity_map.upgradable_read();
250            let count = entity_map
251                .counts
252                .get(self.entity_id)
253                .expect("detected over-release of a handle.");
254            let prev_count = count.fetch_sub(1, SeqCst);
255            assert_ne!(prev_count, 0, "Detected over-release of a model.");
256            if prev_count == 1 {
257                // We were the last reference to this entity, so we can remove it.
258                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
259                entity_map.dropped_entity_ids.push(self.entity_id);
260            }
261        }
262
263        #[cfg(any(test, feature = "test-support"))]
264        if let Some(entity_map) = self.entity_map.upgrade() {
265            entity_map
266                .write()
267                .leak_detector
268                .handle_dropped(self.entity_id, self.handle_id)
269        }
270    }
271}
272
273impl<T> From<Model<T>> for AnyModel {
274    fn from(model: Model<T>) -> Self {
275        model.any_model
276    }
277}
278
279impl Hash for AnyModel {
280    fn hash<H: Hasher>(&self, state: &mut H) {
281        self.entity_id.hash(state);
282    }
283}
284
285impl PartialEq for AnyModel {
286    fn eq(&self, other: &Self) -> bool {
287        self.entity_id == other.entity_id
288    }
289}
290
291impl Eq for AnyModel {}
292
293impl std::fmt::Debug for AnyModel {
294    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295        f.debug_struct("AnyModel")
296            .field("entity_id", &self.entity_id.as_u64())
297            .finish()
298    }
299}
300
301#[derive(Deref, DerefMut)]
302pub struct Model<T> {
303    #[deref]
304    #[deref_mut]
305    pub(crate) any_model: AnyModel,
306    pub(crate) entity_type: PhantomData<T>,
307}
308
309unsafe impl<T> Send for Model<T> {}
310unsafe impl<T> Sync for Model<T> {}
311impl<T> Sealed for Model<T> {}
312
313impl<T: 'static> Entity<T> for Model<T> {
314    type Weak = WeakModel<T>;
315
316    fn entity_id(&self) -> EntityId {
317        self.any_model.entity_id
318    }
319
320    fn downgrade(&self) -> Self::Weak {
321        WeakModel {
322            any_model: self.any_model.downgrade(),
323            entity_type: self.entity_type,
324        }
325    }
326
327    fn upgrade_from(weak: &Self::Weak) -> Option<Self>
328    where
329        Self: Sized,
330    {
331        Some(Model {
332            any_model: weak.any_model.upgrade()?,
333            entity_type: weak.entity_type,
334        })
335    }
336}
337
338impl<T: 'static> Model<T> {
339    fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
340    where
341        T: 'static,
342    {
343        Self {
344            any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
345            entity_type: PhantomData,
346        }
347    }
348
349    /// Downgrade the this to a weak model reference
350    pub fn downgrade(&self) -> WeakModel<T> {
351        // Delegate to the trait implementation to keep behavior in one place.
352        // This method was included to improve method resolution in the presence of
353        // the Model's deref
354        Entity::downgrade(self)
355    }
356
357    /// Convert this into a dynamically typed model.
358    pub fn into_any(self) -> AnyModel {
359        self.any_model
360    }
361
362    pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
363        cx.entities.read(self)
364    }
365
366    pub fn read_with<R, C: Context>(
367        &self,
368        cx: &C,
369        f: impl FnOnce(&T, &AppContext) -> R,
370    ) -> C::Result<R> {
371        cx.read_model(self, f)
372    }
373
374    /// Update the entity referenced by this model with the given function.
375    ///
376    /// The update function receives a context appropriate for its environment.
377    /// When updating in an `AppContext`, it receives a `ModelContext`.
378    /// When updating an a `WindowContext`, it receives a `ViewContext`.
379    pub fn update<C, R>(
380        &self,
381        cx: &mut C,
382        update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
383    ) -> C::Result<R>
384    where
385        C: Context,
386    {
387        cx.update_model(self, update)
388    }
389}
390
391impl<T> Clone for Model<T> {
392    fn clone(&self) -> Self {
393        Self {
394            any_model: self.any_model.clone(),
395            entity_type: self.entity_type,
396        }
397    }
398}
399
400impl<T> std::fmt::Debug for Model<T> {
401    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402        write!(
403            f,
404            "Model {{ entity_id: {:?}, entity_type: {:?} }}",
405            self.any_model.entity_id,
406            type_name::<T>()
407        )
408    }
409}
410
411impl<T> Hash for Model<T> {
412    fn hash<H: Hasher>(&self, state: &mut H) {
413        self.any_model.hash(state);
414    }
415}
416
417impl<T> PartialEq for Model<T> {
418    fn eq(&self, other: &Self) -> bool {
419        self.any_model == other.any_model
420    }
421}
422
423impl<T> Eq for Model<T> {}
424
425impl<T> PartialEq<WeakModel<T>> for Model<T> {
426    fn eq(&self, other: &WeakModel<T>) -> bool {
427        self.any_model.entity_id() == other.entity_id()
428    }
429}
430
431#[derive(Clone)]
432pub struct AnyWeakModel {
433    pub(crate) entity_id: EntityId,
434    entity_type: TypeId,
435    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
436}
437
438impl AnyWeakModel {
439    pub fn entity_id(&self) -> EntityId {
440        self.entity_id
441    }
442
443    pub fn is_upgradable(&self) -> bool {
444        let ref_count = self
445            .entity_ref_counts
446            .upgrade()
447            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
448            .unwrap_or(0);
449        ref_count > 0
450    }
451
452    pub fn upgrade(&self) -> Option<AnyModel> {
453        let ref_counts = &self.entity_ref_counts.upgrade()?;
454        let ref_counts = ref_counts.read();
455        let ref_count = ref_counts.counts.get(self.entity_id)?;
456
457        // entity_id is in dropped_entity_ids
458        if ref_count.load(SeqCst) == 0 {
459            return None;
460        }
461        ref_count.fetch_add(1, SeqCst);
462        drop(ref_counts);
463
464        Some(AnyModel {
465            entity_id: self.entity_id,
466            entity_type: self.entity_type,
467            entity_map: self.entity_ref_counts.clone(),
468            #[cfg(any(test, feature = "test-support"))]
469            handle_id: self
470                .entity_ref_counts
471                .upgrade()
472                .unwrap()
473                .write()
474                .leak_detector
475                .handle_created(self.entity_id),
476        })
477    }
478
479    #[cfg(any(test, feature = "test-support"))]
480    pub fn assert_dropped(&self) {
481        self.entity_ref_counts
482            .upgrade()
483            .unwrap()
484            .write()
485            .leak_detector
486            .assert_dropped(self.entity_id);
487
488        if self
489            .entity_ref_counts
490            .upgrade()
491            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
492            .is_some()
493        {
494            panic!(
495                "entity was recently dropped but resources are retained until the end of the effect cycle."
496            )
497        }
498    }
499}
500
501impl<T> From<WeakModel<T>> for AnyWeakModel {
502    fn from(model: WeakModel<T>) -> Self {
503        model.any_model
504    }
505}
506
507impl Hash for AnyWeakModel {
508    fn hash<H: Hasher>(&self, state: &mut H) {
509        self.entity_id.hash(state);
510    }
511}
512
513impl PartialEq for AnyWeakModel {
514    fn eq(&self, other: &Self) -> bool {
515        self.entity_id == other.entity_id
516    }
517}
518
519impl Eq for AnyWeakModel {}
520
521#[derive(Deref, DerefMut)]
522pub struct WeakModel<T> {
523    #[deref]
524    #[deref_mut]
525    any_model: AnyWeakModel,
526    entity_type: PhantomData<T>,
527}
528
529unsafe impl<T> Send for WeakModel<T> {}
530unsafe impl<T> Sync for WeakModel<T> {}
531
532impl<T> Clone for WeakModel<T> {
533    fn clone(&self) -> Self {
534        Self {
535            any_model: self.any_model.clone(),
536            entity_type: self.entity_type,
537        }
538    }
539}
540
541impl<T: 'static> WeakModel<T> {
542    /// Upgrade this weak model reference into a strong model reference
543    pub fn upgrade(&self) -> Option<Model<T>> {
544        // Delegate to the trait implementation to keep behavior in one place.
545        Model::upgrade_from(self)
546    }
547
548    /// Update the entity referenced by this model with the given function if
549    /// the referenced entity still exists. Returns an error if the entity has
550    /// been released.
551    pub fn update<C, R>(
552        &self,
553        cx: &mut C,
554        update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
555    ) -> Result<R>
556    where
557        C: Context,
558        Result<C::Result<R>>: crate::Flatten<R>,
559    {
560        crate::Flatten::flatten(
561            self.upgrade()
562                .ok_or_else(|| anyhow!("entity release"))
563                .map(|this| cx.update_model(&this, update)),
564        )
565    }
566
567    /// Reads the entity referenced by this model with the given function if
568    /// the referenced entity still exists. Returns an error if the entity has
569    /// been released.
570    pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
571    where
572        C: Context,
573        Result<C::Result<R>>: crate::Flatten<R>,
574    {
575        crate::Flatten::flatten(
576            self.upgrade()
577                .ok_or_else(|| anyhow!("entity release"))
578                .map(|this| cx.read_model(&this, read)),
579        )
580    }
581}
582
583impl<T> Hash for WeakModel<T> {
584    fn hash<H: Hasher>(&self, state: &mut H) {
585        self.any_model.hash(state);
586    }
587}
588
589impl<T> PartialEq for WeakModel<T> {
590    fn eq(&self, other: &Self) -> bool {
591        self.any_model == other.any_model
592    }
593}
594
595impl<T> Eq for WeakModel<T> {}
596
597impl<T> PartialEq<Model<T>> for WeakModel<T> {
598    fn eq(&self, other: &Model<T>) -> bool {
599        self.entity_id() == other.any_model.entity_id()
600    }
601}
602
#[cfg(any(test, feature = "test-support"))]
lazy_static::lazy_static! {
    /// Whether the leak detector captures a backtrace for every handle it
    /// records: true when the `LEAK_BACKTRACE` environment variable is set to a
    /// non-empty value.
    static ref LEAK_BACKTRACE: bool =
        matches!(std::env::var("LEAK_BACKTRACE"), Ok(value) if !value.is_empty());
}
608
/// Identity of one individual handle (not of the entity it points at), so the
/// leak detector can tell handles to the same entity apart.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct HandleId {
    id: u64,
}

/// Test-support registry of live handles per entity, used to assert that an
/// entity was fully released and to report where leaked handles were created.
#[cfg(any(test, feature = "test-support"))]
pub struct LeakDetector {
    /// Counter used to mint new `HandleId`s.
    next_handle_id: u64,
    /// Live handles per entity, each with an optional (unresolved) backtrace of
    /// where it was created.
    entity_handles: HashMap<EntityId, HashMap<HandleId, Option<backtrace::Backtrace>>>,
}
620
#[cfg(any(test, feature = "test-support"))]
impl LeakDetector {
    /// Registers a new handle to `entity_id` and returns its unique id.
    ///
    /// When `LEAK_BACKTRACE` is enabled, an unresolved backtrace of the
    /// creation site is stored so a leak can later be traced to it.
    #[track_caller]
    pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId {
        let id = util::post_inc(&mut self.next_handle_id);
        let handle_id = HandleId { id };
        let handles = self.entity_handles.entry(entity_id).or_default();
        handles.insert(
            handle_id,
            LEAK_BACKTRACE.then(|| backtrace::Backtrace::new_unresolved()),
        );
        handle_id
    }

    /// Forgets `handle_id`.
    ///
    /// When an entity has no remaining handles its entry is removed entirely,
    /// so `entity_handles` does not keep an empty map for every entity ever
    /// created. Unknown entities are ignored rather than inserted.
    pub fn handle_dropped(&mut self, entity_id: EntityId, handle_id: HandleId) {
        if let Some(handles) = self.entity_handles.get_mut(&entity_id) {
            handles.remove(&handle_id);
            if handles.is_empty() {
                self.entity_handles.remove(&entity_id);
            }
        }
    }

    /// Panics if any handle to `entity_id` is still registered.
    ///
    /// Before panicking, prints each leaked handle's creation backtrace to
    /// stderr (or a hint to set `LEAK_BACKTRACE` when none was captured).
    pub fn assert_dropped(&mut self, entity_id: EntityId) {
        let handles = match self.entity_handles.get_mut(&entity_id) {
            Some(handles) if !handles.is_empty() => handles,
            // No entry (or an empty one) means every handle was dropped.
            _ => return,
        };

        let leaked_count = handles.len();
        for backtrace in handles.values_mut() {
            if let Some(mut backtrace) = backtrace.take() {
                backtrace.resolve();
                eprintln!("Leaked handle: {:#?}", backtrace);
            } else {
                eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site");
            }
        }
        panic!(
            "{} handle(s) to entity {} were leaked",
            leaked_count, entity_id
        );
    }
}
655
#[cfg(test)]
mod test {
    use crate::EntityMap;

    struct TestEntity {
        pub i: i32,
    }

    /// Extracts the `i` field of each dropped entity, preserving order.
    fn dropped_values<K>(dropped: Vec<(K, Box<dyn std::any::Any>)>) -> Vec<i32> {
        dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect()
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Slots must not be reused before `take_dropped` runs: both entities
        // lose their only handle immediately and must both be reported.
        let mut entity_map = EntityMap::new();

        for i in [1, 2] {
            let slot = entity_map.reserve::<TestEntity>();
            drop(entity_map.insert(slot, TestEntity { i }));
        }

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);
        assert_eq!(dropped_values(dropped), vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // A weak handle must not upgrade once the last strong handle is gone,
        // even though `take_dropped` has not reclaimed the entity yet.
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        let handle = entity_map.insert(slot, TestEntity { i: 1 });
        let weak = handle.downgrade();
        drop(handle);

        assert_eq!(weak.upgrade(), None);

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);
        assert_eq!(dropped_values(dropped), vec![1]);
    }
}