1use crate::{seal::Sealed, AppContext, Context, Entity, ModelContext};
2use anyhow::{anyhow, Result};
3use derive_more::{Deref, DerefMut};
4use parking_lot::{RwLock, RwLockUpgradableReadGuard};
5use slotmap::{SecondaryMap, SlotMap};
6use std::{
7 any::{type_name, Any, TypeId},
8 fmt::{self, Display},
9 hash::{Hash, Hasher},
10 marker::PhantomData,
11 mem,
12 sync::{
13 atomic::{AtomicUsize, Ordering::SeqCst},
14 Arc, Weak,
15 },
16 thread::panicking,
17};
18
19#[cfg(any(test, feature = "test-support"))]
20use collections::HashMap;
21
// Key type used to index the `SlotMap` / `SecondaryMap` collections in this
// file; declared through slotmap's `new_key_type!` macro.
slotmap::new_key_type! { pub struct EntityId; }
23
24impl EntityId {
25 pub fn as_u64(self) -> u64 {
26 self.0.as_ffi()
27 }
28}
29
30impl Display for EntityId {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 write!(f, "{}", self.as_u64())
33 }
34}
35
/// Stores type-erased entity values keyed by `EntityId`, together with the
/// reference-count state shared with the handles to those entities.
pub(crate) struct EntityMap {
    // Boxed entity values. `lease` removes an entry and `end_lease` puts it
    // back, so an entity is absent from this map while it is leased.
    entities: SecondaryMap<EntityId, Box<dyn Any>>,
    // Reference counts; handles keep a `Weak` pointer to this allocation.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
40
/// Reference-count bookkeeping shared between an `EntityMap` and its handles.
struct EntityRefCounts {
    // Number of live strong handles (`AnyModel`) per entity id.
    counts: SlotMap<EntityId, AtomicUsize>,
    // Ids whose last strong handle was dropped; drained by
    // `EntityMap::take_dropped`.
    dropped_entity_ids: Vec<EntityId>,
    // Per-handle tracking, only compiled for tests / the "test-support" feature.
    #[cfg(any(test, feature = "test-support"))]
    leak_detector: LeakDetector,
}
47
48impl EntityMap {
49 pub fn new() -> Self {
50 Self {
51 entities: SecondaryMap::new(),
52 ref_counts: Arc::new(RwLock::new(EntityRefCounts {
53 counts: SlotMap::with_key(),
54 dropped_entity_ids: Vec::new(),
55 #[cfg(any(test, feature = "test-support"))]
56 leak_detector: LeakDetector {
57 next_handle_id: 0,
58 entity_handles: HashMap::default(),
59 },
60 })),
61 }
62 }
63
64 /// Reserve a slot for an entity, which you can subsequently use with `insert`.
65 pub fn reserve<T: 'static>(&self) -> Slot<T> {
66 let id = self.ref_counts.write().counts.insert(1.into());
67 Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
68 }
69
70 /// Insert an entity into a slot obtained by calling `reserve`.
71 pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
72 where
73 T: 'static,
74 {
75 let model = slot.0;
76 self.entities.insert(model.entity_id, Box::new(entity));
77 model
78 }
79
80 /// Move an entity to the stack.
81 #[track_caller]
82 pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
83 self.assert_valid_context(model);
84 let entity = Some(self.entities.remove(model.entity_id).unwrap_or_else(|| {
85 panic!(
86 "Circular entity lease of {}. Is it already being updated?",
87 std::any::type_name::<T>()
88 )
89 }));
90 Lease {
91 model,
92 entity,
93 entity_type: PhantomData,
94 }
95 }
96
97 /// Return an entity after moving it to the stack.
98 pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
99 self.entities
100 .insert(lease.model.entity_id, lease.entity.take().unwrap());
101 }
102
103 pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
104 self.assert_valid_context(model);
105 self.entities[model.entity_id].downcast_ref().unwrap()
106 }
107
108 fn assert_valid_context(&self, model: &AnyModel) {
109 debug_assert!(
110 Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
111 "used a model with the wrong context"
112 );
113 }
114
115 pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
116 let mut ref_counts = self.ref_counts.write();
117 let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
118
119 dropped_entity_ids
120 .into_iter()
121 .map(|entity_id| {
122 let count = ref_counts.counts.remove(entity_id).unwrap();
123 debug_assert_eq!(
124 count.load(SeqCst),
125 0,
126 "dropped an entity that was referenced"
127 );
128 (entity_id, self.entities.remove(entity_id).unwrap())
129 })
130 .collect()
131 }
132}
133
/// An entity value temporarily taken out of an `EntityMap` by
/// `EntityMap::lease`; it must be handed back with `EntityMap::end_lease`.
pub struct Lease<'a, T> {
    // `Some` while the lease holds the value; `end_lease` takes it out. The
    // `Drop` impl panics if this is still `Some` (unless already panicking).
    entity: Option<Box<dyn Any>>,
    // The handle the entity was leased through.
    pub model: &'a Model<T>,
    entity_type: PhantomData<T>,
}
139
140impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
141 type Target = T;
142
143 fn deref(&self) -> &Self::Target {
144 self.entity.as_ref().unwrap().downcast_ref().unwrap()
145 }
146}
147
148impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
149 fn deref_mut(&mut self) -> &mut Self::Target {
150 self.entity.as_mut().unwrap().downcast_mut().unwrap()
151 }
152}
153
154impl<'a, T> Drop for Lease<'a, T> {
155 fn drop(&mut self) {
156 if self.entity.is_some() && !panicking() {
157 panic!("Leases must be ended with EntityMap::end_lease")
158 }
159 }
160}
161
/// A reserved entity slot: returned by `EntityMap::reserve` and consumed by
/// `EntityMap::insert`. Wraps (and derefs to) the `Model` for the reserved id.
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Model<T>);
164
/// A dynamically typed strong handle to an entity. Cloning increments the
/// entity's reference count and dropping decrements it (see the `Clone` and
/// `Drop` impls).
pub struct AnyModel {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the entity's concrete type; checked by `downcast`.
    pub(crate) entity_type: TypeId,
    // Weak pointer to the reference counts of the owning `EntityMap`.
    entity_map: Weak<RwLock<EntityRefCounts>>,
    // Identifies this particular handle to the leak detector.
    #[cfg(any(test, feature = "test-support"))]
    handle_id: HandleId,
}
172
173impl AnyModel {
174 fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
175 Self {
176 entity_id: id,
177 entity_type,
178 entity_map: entity_map.clone(),
179 #[cfg(any(test, feature = "test-support"))]
180 handle_id: entity_map
181 .upgrade()
182 .unwrap()
183 .write()
184 .leak_detector
185 .handle_created(id),
186 }
187 }
188
189 pub fn entity_id(&self) -> EntityId {
190 self.entity_id
191 }
192
193 pub fn entity_type(&self) -> TypeId {
194 self.entity_type
195 }
196
197 pub fn downgrade(&self) -> AnyWeakModel {
198 AnyWeakModel {
199 entity_id: self.entity_id,
200 entity_type: self.entity_type,
201 entity_ref_counts: self.entity_map.clone(),
202 }
203 }
204
205 pub fn downcast<T: 'static>(self) -> Result<Model<T>, AnyModel> {
206 if TypeId::of::<T>() == self.entity_type {
207 Ok(Model {
208 any_model: self,
209 entity_type: PhantomData,
210 })
211 } else {
212 Err(self)
213 }
214 }
215}
216
217impl Clone for AnyModel {
218 fn clone(&self) -> Self {
219 if let Some(entity_map) = self.entity_map.upgrade() {
220 let entity_map = entity_map.read();
221 let count = entity_map
222 .counts
223 .get(self.entity_id)
224 .expect("detected over-release of a model");
225 let prev_count = count.fetch_add(1, SeqCst);
226 assert_ne!(prev_count, 0, "Detected over-release of a model.");
227 }
228
229 let this = Self {
230 entity_id: self.entity_id,
231 entity_type: self.entity_type,
232 entity_map: self.entity_map.clone(),
233 #[cfg(any(test, feature = "test-support"))]
234 handle_id: self
235 .entity_map
236 .upgrade()
237 .unwrap()
238 .write()
239 .leak_detector
240 .handle_created(self.entity_id),
241 };
242 this
243 }
244}
245
246impl Drop for AnyModel {
247 fn drop(&mut self) {
248 if let Some(entity_map) = self.entity_map.upgrade() {
249 let entity_map = entity_map.upgradable_read();
250 let count = entity_map
251 .counts
252 .get(self.entity_id)
253 .expect("detected over-release of a handle.");
254 let prev_count = count.fetch_sub(1, SeqCst);
255 assert_ne!(prev_count, 0, "Detected over-release of a model.");
256 if prev_count == 1 {
257 // We were the last reference to this entity, so we can remove it.
258 let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
259 entity_map.dropped_entity_ids.push(self.entity_id);
260 }
261 }
262
263 #[cfg(any(test, feature = "test-support"))]
264 if let Some(entity_map) = self.entity_map.upgrade() {
265 entity_map
266 .write()
267 .leak_detector
268 .handle_dropped(self.entity_id, self.handle_id)
269 }
270 }
271}
272
273impl<T> From<Model<T>> for AnyModel {
274 fn from(model: Model<T>) -> Self {
275 model.any_model
276 }
277}
278
279impl Hash for AnyModel {
280 fn hash<H: Hasher>(&self, state: &mut H) {
281 self.entity_id.hash(state);
282 }
283}
284
285impl PartialEq for AnyModel {
286 fn eq(&self, other: &Self) -> bool {
287 self.entity_id == other.entity_id
288 }
289}
290
291impl Eq for AnyModel {}
292
293impl std::fmt::Debug for AnyModel {
294 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
295 f.debug_struct("AnyModel")
296 .field("entity_id", &self.entity_id.as_u64())
297 .finish()
298 }
299}
300
/// A statically typed strong handle to an entity of type `T`. Wraps an
/// `AnyModel` and derefs to it.
#[derive(Deref, DerefMut)]
pub struct Model<T> {
    #[deref]
    #[deref_mut]
    pub(crate) any_model: AnyModel,
    // Carries the entity type only; no `T` is stored in the handle.
    pub(crate) entity_type: PhantomData<T>,
}

// NOTE(review): no safety justification is written in this file. Presumably
// this relies on the handle never storing a `T` itself (only an id, a `TypeId`
// and a `Weak` to the lock-protected counts) — confirm.
unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Model<T> {}
impl<T> Sealed for Model<T> {}
312
313impl<T: 'static> Entity<T> for Model<T> {
314 type Weak = WeakModel<T>;
315
316 fn entity_id(&self) -> EntityId {
317 self.any_model.entity_id
318 }
319
320 fn downgrade(&self) -> Self::Weak {
321 WeakModel {
322 any_model: self.any_model.downgrade(),
323 entity_type: self.entity_type,
324 }
325 }
326
327 fn upgrade_from(weak: &Self::Weak) -> Option<Self>
328 where
329 Self: Sized,
330 {
331 Some(Model {
332 any_model: weak.any_model.upgrade()?,
333 entity_type: weak.entity_type,
334 })
335 }
336}
337
338impl<T: 'static> Model<T> {
339 fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
340 where
341 T: 'static,
342 {
343 Self {
344 any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
345 entity_type: PhantomData,
346 }
347 }
348
349 /// Downgrade the this to a weak model reference
350 pub fn downgrade(&self) -> WeakModel<T> {
351 // Delegate to the trait implementation to keep behavior in one place.
352 // This method was included to improve method resolution in the presence of
353 // the Model's deref
354 Entity::downgrade(self)
355 }
356
357 /// Convert this into a dynamically typed model.
358 pub fn into_any(self) -> AnyModel {
359 self.any_model
360 }
361
362 pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
363 cx.entities.read(self)
364 }
365
366 pub fn read_with<R, C: Context>(
367 &self,
368 cx: &C,
369 f: impl FnOnce(&T, &AppContext) -> R,
370 ) -> C::Result<R> {
371 cx.read_model(self, f)
372 }
373
374 /// Update the entity referenced by this model with the given function.
375 ///
376 /// The update function receives a context appropriate for its environment.
377 /// When updating in an `AppContext`, it receives a `ModelContext`.
378 /// When updating an a `WindowContext`, it receives a `ViewContext`.
379 pub fn update<C, R>(
380 &self,
381 cx: &mut C,
382 update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
383 ) -> C::Result<R>
384 where
385 C: Context,
386 {
387 cx.update_model(self, update)
388 }
389}
390
391impl<T> Clone for Model<T> {
392 fn clone(&self) -> Self {
393 Self {
394 any_model: self.any_model.clone(),
395 entity_type: self.entity_type,
396 }
397 }
398}
399
400impl<T> std::fmt::Debug for Model<T> {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 write!(
403 f,
404 "Model {{ entity_id: {:?}, entity_type: {:?} }}",
405 self.any_model.entity_id,
406 type_name::<T>()
407 )
408 }
409}
410
411impl<T> Hash for Model<T> {
412 fn hash<H: Hasher>(&self, state: &mut H) {
413 self.any_model.hash(state);
414 }
415}
416
417impl<T> PartialEq for Model<T> {
418 fn eq(&self, other: &Self) -> bool {
419 self.any_model == other.any_model
420 }
421}
422
423impl<T> Eq for Model<T> {}
424
425impl<T> PartialEq<WeakModel<T>> for Model<T> {
426 fn eq(&self, other: &WeakModel<T>) -> bool {
427 self.any_model.entity_id() == other.entity_id()
428 }
429}
430
/// A dynamically typed weak handle to an entity. Holding one does not change
/// the entity's reference count; use `upgrade` to obtain a strong handle.
#[derive(Clone)]
pub struct AnyWeakModel {
    pub(crate) entity_id: EntityId,
    // `TypeId` copied from the strong handle this was downgraded from.
    entity_type: TypeId,
    // Weak pointer to the reference counts of the owning `EntityMap`.
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
437
438impl AnyWeakModel {
439 pub fn entity_id(&self) -> EntityId {
440 self.entity_id
441 }
442
443 pub fn is_upgradable(&self) -> bool {
444 let ref_count = self
445 .entity_ref_counts
446 .upgrade()
447 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
448 .unwrap_or(0);
449 ref_count > 0
450 }
451
452 pub fn upgrade(&self) -> Option<AnyModel> {
453 let ref_counts = &self.entity_ref_counts.upgrade()?;
454 let ref_counts = ref_counts.read();
455 let ref_count = ref_counts.counts.get(self.entity_id)?;
456
457 // entity_id is in dropped_entity_ids
458 if ref_count.load(SeqCst) == 0 {
459 return None;
460 }
461 ref_count.fetch_add(1, SeqCst);
462 drop(ref_counts);
463
464 Some(AnyModel {
465 entity_id: self.entity_id,
466 entity_type: self.entity_type,
467 entity_map: self.entity_ref_counts.clone(),
468 #[cfg(any(test, feature = "test-support"))]
469 handle_id: self
470 .entity_ref_counts
471 .upgrade()
472 .unwrap()
473 .write()
474 .leak_detector
475 .handle_created(self.entity_id),
476 })
477 }
478
479 #[cfg(any(test, feature = "test-support"))]
480 pub fn assert_dropped(&self) {
481 self.entity_ref_counts
482 .upgrade()
483 .unwrap()
484 .write()
485 .leak_detector
486 .assert_dropped(self.entity_id);
487
488 if self
489 .entity_ref_counts
490 .upgrade()
491 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
492 .is_some()
493 {
494 panic!(
495 "entity was recently dropped but resources are retained until the end of the effect cycle."
496 )
497 }
498 }
499}
500
501impl<T> From<WeakModel<T>> for AnyWeakModel {
502 fn from(model: WeakModel<T>) -> Self {
503 model.any_model
504 }
505}
506
507impl Hash for AnyWeakModel {
508 fn hash<H: Hasher>(&self, state: &mut H) {
509 self.entity_id.hash(state);
510 }
511}
512
513impl PartialEq for AnyWeakModel {
514 fn eq(&self, other: &Self) -> bool {
515 self.entity_id == other.entity_id
516 }
517}
518
519impl Eq for AnyWeakModel {}
520
/// A statically typed weak handle to an entity of type `T`. Wraps an
/// `AnyWeakModel` and derefs to it.
#[derive(Deref, DerefMut)]
pub struct WeakModel<T> {
    #[deref]
    #[deref_mut]
    any_model: AnyWeakModel,
    // Carries the entity type only; no `T` is stored in the handle.
    entity_type: PhantomData<T>,
}

// NOTE(review): no safety justification is written in this file. Presumably
// this relies on the handle never storing a `T` itself — confirm.
unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakModel<T> {}
531
532impl<T> Clone for WeakModel<T> {
533 fn clone(&self) -> Self {
534 Self {
535 any_model: self.any_model.clone(),
536 entity_type: self.entity_type,
537 }
538 }
539}
540
541impl<T: 'static> WeakModel<T> {
542 /// Upgrade this weak model reference into a strong model reference
543 pub fn upgrade(&self) -> Option<Model<T>> {
544 // Delegate to the trait implementation to keep behavior in one place.
545 Model::upgrade_from(self)
546 }
547
548 /// Update the entity referenced by this model with the given function if
549 /// the referenced entity still exists. Returns an error if the entity has
550 /// been released.
551 pub fn update<C, R>(
552 &self,
553 cx: &mut C,
554 update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
555 ) -> Result<R>
556 where
557 C: Context,
558 Result<C::Result<R>>: crate::Flatten<R>,
559 {
560 crate::Flatten::flatten(
561 self.upgrade()
562 .ok_or_else(|| anyhow!("entity release"))
563 .map(|this| cx.update_model(&this, update)),
564 )
565 }
566
567 /// Reads the entity referenced by this model with the given function if
568 /// the referenced entity still exists. Returns an error if the entity has
569 /// been released.
570 pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
571 where
572 C: Context,
573 Result<C::Result<R>>: crate::Flatten<R>,
574 {
575 crate::Flatten::flatten(
576 self.upgrade()
577 .ok_or_else(|| anyhow!("entity release"))
578 .map(|this| cx.read_model(&this, read)),
579 )
580 }
581}
582
583impl<T> Hash for WeakModel<T> {
584 fn hash<H: Hasher>(&self, state: &mut H) {
585 self.any_model.hash(state);
586 }
587}
588
589impl<T> PartialEq for WeakModel<T> {
590 fn eq(&self, other: &Self) -> bool {
591 self.any_model == other.any_model
592 }
593}
594
595impl<T> Eq for WeakModel<T> {}
596
597impl<T> PartialEq<Model<T>> for WeakModel<T> {
598 fn eq(&self, other: &Model<T>) -> bool {
599 self.entity_id() == other.any_model.entity_id()
600 }
601}
602
#[cfg(any(test, feature = "test-support"))]
lazy_static::lazy_static! {
    // True when the `LEAK_BACKTRACE` environment variable is set to a
    // non-empty value; read once on first access.
    static ref LEAK_BACKTRACE: bool =
        matches!(std::env::var("LEAK_BACKTRACE"), Ok(value) if !value.is_empty());
}
608
/// Identifier the leak detector assigns to each individual strong handle.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct HandleId {
    id: u64, // id of the handle itself, not the pointed at object
}
614
/// Test-only bookkeeping of which strong handles currently exist for each
/// entity.
#[cfg(any(test, feature = "test-support"))]
pub struct LeakDetector {
    // Value handed out as the next `HandleId`.
    next_handle_id: u64,
    // Live handles per entity; the backtrace is only captured when
    // `LEAK_BACKTRACE` is enabled.
    entity_handles: HashMap<EntityId, HashMap<HandleId, Option<backtrace::Backtrace>>>,
}
620
#[cfg(any(test, feature = "test-support"))]
impl LeakDetector {
    /// Registers a new handle for `entity_id` and returns its id.
    ///
    /// An unresolved backtrace of the creation site is stored alongside it
    /// when `LEAK_BACKTRACE` is enabled.
    #[track_caller]
    pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId {
        let id = util::post_inc(&mut self.next_handle_id);
        let handle_id = HandleId { id };
        let handles = self.entity_handles.entry(entity_id).or_default();
        handles.insert(
            handle_id,
            LEAK_BACKTRACE.then(backtrace::Backtrace::new_unresolved),
        );
        handle_id
    }

    /// Forgets a handle previously registered with `handle_created`.
    ///
    /// Unknown entities or handles are ignored. Once an entity has no tracked
    /// handles left its entry is removed, so the map does not accumulate
    /// empty entries.
    pub fn handle_dropped(&mut self, entity_id: EntityId, handle_id: HandleId) {
        if let Some(handles) = self.entity_handles.get_mut(&entity_id) {
            handles.remove(&handle_id);
            if handles.is_empty() {
                self.entity_handles.remove(&entity_id);
            }
        }
    }

    /// Asserts that no handles are tracked for `entity_id`.
    ///
    /// # Panics
    ///
    /// Panics if any handle is still tracked, after printing one line (or a
    /// resolved backtrace, when one was captured) per leaked handle to stderr.
    pub fn assert_dropped(&mut self, entity_id: EntityId) {
        // Look the entity up without inserting an entry for it.
        let handles = match self.entity_handles.get_mut(&entity_id) {
            Some(handles) if !handles.is_empty() => handles,
            _ => return,
        };

        for backtrace in handles.values_mut() {
            if let Some(mut backtrace) = backtrace.take() {
                backtrace.resolve();
                eprintln!("Leaked handle: {:#?}", backtrace);
            } else {
                eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site");
            }
        }
        panic!(
            "entity {} still has {} leaked handle(s)",
            entity_id,
            handles.len()
        );
    }
}
655
#[cfg(test)]
mod test {
    use crate::EntityMap;

    struct TestEntity {
        pub i: i32,
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Each handle is dropped right after insertion. If a slot were re-used
        // before `take_dropped`, the two entities would not both come back.
        let mut entity_map = EntityMap::new();

        for i in 1..=2 {
            let slot = entity_map.reserve::<TestEntity>();
            entity_map.insert(slot, TestEntity { i });
        }

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // Once the last strong handle is gone, a weak handle must not upgrade,
        // even though `take_dropped` has not run yet.
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        let handle = entity_map.insert(slot, TestEntity { i: 1 });
        let weak = handle.downgrade();
        drop(handle);

        assert!(weak.upgrade().is_none());

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1]);
    }
}