1use crate::{private::Sealed, AppContext, Context, Entity, ModelContext};
2use anyhow::{anyhow, Result};
3use collections::HashMap;
4use derive_more::{Deref, DerefMut};
5use lazy_static::lazy_static;
6use parking_lot::{RwLock, RwLockUpgradableReadGuard};
7use slotmap::{SecondaryMap, SlotMap};
8use std::{
9 any::{type_name, Any, TypeId},
10 fmt::{self, Display},
11 hash::{Hash, Hasher},
12 marker::PhantomData,
13 mem,
14 sync::{
15 atomic::{AtomicUsize, Ordering::SeqCst},
16 Arc, Weak,
17 },
18 thread::panicking,
19};
20
slotmap::new_key_type! {
    /// Identifies a single entity; used as the key of the entity storage and of
    /// the reference-count table in this file.
    pub struct EntityId;
}
22
23impl EntityId {
24 pub fn as_u64(self) -> u64 {
25 self.0.as_ffi()
26 }
27}
28
29impl Display for EntityId {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 write!(f, "{}", self.as_u64())
32 }
33}
34
/// Stores the state of every entity together with the reference counts of the
/// handles that point at those entities.
pub(crate) struct EntityMap {
    /// Type-erased entity state, keyed by the id allocated in `ref_counts`.
    entities: SecondaryMap<EntityId, Box<dyn Any>>,
    /// Reference-count table; handles keep a `Weak` pointer to it.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
39
/// Bookkeeping shared between an `EntityMap` and the handles it hands out.
struct EntityRefCounts {
    /// Number of live strong handles for each entity id.
    counts: SlotMap<EntityId, AtomicUsize>,
    /// Ids whose count dropped to zero and that await `EntityMap::take_dropped`.
    dropped_entity_ids: Vec<EntityId>,
    /// Tracks individual handles so that leaked handles can be reported.
    #[cfg(any(test, feature = "test-support"))]
    leak_detector: LeakDetector,
}
46
47impl EntityMap {
48 pub fn new() -> Self {
49 Self {
50 entities: SecondaryMap::new(),
51 ref_counts: Arc::new(RwLock::new(EntityRefCounts {
52 counts: SlotMap::with_key(),
53 dropped_entity_ids: Vec::new(),
54 #[cfg(any(test, feature = "test-support"))]
55 leak_detector: LeakDetector {
56 next_handle_id: 0,
57 entity_handles: HashMap::default(),
58 },
59 })),
60 }
61 }
62
63 /// Reserve a slot for an entity, which you can subsequently use with `insert`.
64 pub fn reserve<T: 'static>(&self) -> Slot<T> {
65 let id = self.ref_counts.write().counts.insert(1.into());
66 Slot(Model::new(id, Arc::downgrade(&self.ref_counts)))
67 }
68
69 /// Insert an entity into a slot obtained by calling `reserve`.
70 pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Model<T>
71 where
72 T: 'static,
73 {
74 let model = slot.0;
75 self.entities.insert(model.entity_id, Box::new(entity));
76 model
77 }
78
79 /// Move an entity to the stack.
80 #[track_caller]
81 pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
82 self.assert_valid_context(model);
83 let entity = Some(self.entities.remove(model.entity_id).unwrap_or_else(|| {
84 panic!(
85 "Circular entity lease of {}. Is it already being updated?",
86 std::any::type_name::<T>()
87 )
88 }));
89 Lease {
90 model,
91 entity,
92 entity_type: PhantomData,
93 }
94 }
95
96 /// Return an entity after moving it to the stack.
97 pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
98 self.entities
99 .insert(lease.model.entity_id, lease.entity.take().unwrap());
100 }
101
102 pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
103 self.assert_valid_context(model);
104 self.entities[model.entity_id].downcast_ref().unwrap()
105 }
106
107 fn assert_valid_context(&self, model: &AnyModel) {
108 debug_assert!(
109 Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
110 "used a model with the wrong context"
111 );
112 }
113
114 pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
115 let mut ref_counts = self.ref_counts.write();
116 let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
117
118 dropped_entity_ids
119 .into_iter()
120 .map(|entity_id| {
121 let count = ref_counts.counts.remove(entity_id).unwrap();
122 debug_assert_eq!(
123 count.load(SeqCst),
124 0,
125 "dropped an entity that was referenced"
126 );
127 (entity_id, self.entities.remove(entity_id).unwrap())
128 })
129 .collect()
130 }
131}
132
/// An entity that has been temporarily removed from its `EntityMap` by
/// `EntityMap::lease` and must be handed back with `EntityMap::end_lease`.
pub struct Lease<'a, T> {
    /// The leased state; `None` once `EntityMap::end_lease` has taken it back.
    entity: Option<Box<dyn Any>>,
    /// Handle of the entity being leased.
    pub model: &'a Model<T>,
    /// Marker for the concrete type that `entity` is downcast to.
    entity_type: PhantomData<T>,
}
138
139impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
140 type Target = T;
141
142 fn deref(&self) -> &Self::Target {
143 self.entity.as_ref().unwrap().downcast_ref().unwrap()
144 }
145}
146
147impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
148 fn deref_mut(&mut self) -> &mut Self::Target {
149 self.entity.as_mut().unwrap().downcast_mut().unwrap()
150 }
151}
152
153impl<'a, T> Drop for Lease<'a, T> {
154 fn drop(&mut self) {
155 if self.entity.is_some() && !panicking() {
156 panic!("Leases must be ended with EntityMap::end_lease")
157 }
158 }
159}
160
/// A reserved entity id, returned by `EntityMap::reserve`, that is filled by
/// passing it to `EntityMap::insert`. Dereferences to the wrapped `Model`.
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Model<T>);
163
/// A dynamically typed strong handle to an entity. Cloning it increments the
/// entity's reference count and dropping it decrements that count.
pub struct AnyModel {
    /// Id of the referenced entity.
    pub(crate) entity_id: EntityId,
    /// `TypeId` of the entity's concrete type; checked by `downcast`.
    pub(crate) entity_type: TypeId,
    /// Weak pointer to the reference-count table this handle is counted in.
    entity_map: Weak<RwLock<EntityRefCounts>>,
    /// Identifies this particular handle to the leak detector.
    #[cfg(any(test, feature = "test-support"))]
    handle_id: HandleId,
}
171
172impl AnyModel {
173 fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
174 Self {
175 entity_id: id,
176 entity_type,
177 entity_map: entity_map.clone(),
178 #[cfg(any(test, feature = "test-support"))]
179 handle_id: entity_map
180 .upgrade()
181 .unwrap()
182 .write()
183 .leak_detector
184 .handle_created(id),
185 }
186 }
187
188 pub fn entity_id(&self) -> EntityId {
189 self.entity_id
190 }
191
192 pub fn entity_type(&self) -> TypeId {
193 self.entity_type
194 }
195
196 pub fn downgrade(&self) -> AnyWeakModel {
197 AnyWeakModel {
198 entity_id: self.entity_id,
199 entity_type: self.entity_type,
200 entity_ref_counts: self.entity_map.clone(),
201 }
202 }
203
204 pub fn downcast<T: 'static>(self) -> Result<Model<T>, AnyModel> {
205 if TypeId::of::<T>() == self.entity_type {
206 Ok(Model {
207 any_model: self,
208 entity_type: PhantomData,
209 })
210 } else {
211 Err(self)
212 }
213 }
214}
215
216impl Clone for AnyModel {
217 fn clone(&self) -> Self {
218 if let Some(entity_map) = self.entity_map.upgrade() {
219 let entity_map = entity_map.read();
220 let count = entity_map
221 .counts
222 .get(self.entity_id)
223 .expect("detected over-release of a model");
224 let prev_count = count.fetch_add(1, SeqCst);
225 assert_ne!(prev_count, 0, "Detected over-release of a model.");
226 }
227
228 let this = Self {
229 entity_id: self.entity_id,
230 entity_type: self.entity_type,
231 entity_map: self.entity_map.clone(),
232 #[cfg(any(test, feature = "test-support"))]
233 handle_id: self
234 .entity_map
235 .upgrade()
236 .unwrap()
237 .write()
238 .leak_detector
239 .handle_created(self.entity_id),
240 };
241 this
242 }
243}
244
impl Drop for AnyModel {
    /// Releases this handle's reference. When it was the last one, the entity's
    /// id is queued in `dropped_entity_ids`.
    ///
    /// Panics if the entity has no count entry or its count was already zero.
    fn drop(&mut self) {
        // Nothing to release once the ref-count table itself is gone.
        if let Some(entity_map) = self.entity_map.upgrade() {
            // Start with an upgradable read lock; it is promoted to a write
            // lock below only when this turns out to be the last handle.
            let entity_map = entity_map.upgradable_read();
            let count = entity_map
                .counts
                .get(self.entity_id)
                .expect("detected over-release of a handle.");
            let prev_count = count.fetch_sub(1, SeqCst);
            assert_ne!(prev_count, 0, "Detected over-release of a model.");
            if prev_count == 1 {
                // We were the last reference to this entity, so queue its id
                // for removal.
                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
                entity_map.dropped_entity_ids.push(self.entity_id);
            }
        }

        // Test builds additionally unregister this handle from the leak detector.
        #[cfg(any(test, feature = "test-support"))]
        if let Some(entity_map) = self.entity_map.upgrade() {
            entity_map
                .write()
                .leak_detector
                .handle_dropped(self.entity_id, self.handle_id)
        }
    }
}
271
272impl<T> From<Model<T>> for AnyModel {
273 fn from(model: Model<T>) -> Self {
274 model.any_model
275 }
276}
277
278impl Hash for AnyModel {
279 fn hash<H: Hasher>(&self, state: &mut H) {
280 self.entity_id.hash(state);
281 }
282}
283
284impl PartialEq for AnyModel {
285 fn eq(&self, other: &Self) -> bool {
286 self.entity_id == other.entity_id
287 }
288}
289
// Marker impl: `AnyModel` equality compares entity ids only.
impl Eq for AnyModel {}
291
292impl std::fmt::Debug for AnyModel {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 f.debug_struct("AnyModel")
295 .field("entity_id", &self.entity_id.as_u64())
296 .finish()
297 }
298}
299
/// A strong handle to an entity whose concrete type `T` is known statically.
/// Dereferences to the untyped `AnyModel` it wraps.
#[derive(Deref, DerefMut)]
pub struct Model<T> {
    /// The untyped handle that carries the id and the reference count.
    #[deref]
    #[deref_mut]
    pub(crate) any_model: AnyModel,
    /// Marker for the entity's type; no `T` is stored in the handle.
    pub(crate) entity_type: PhantomData<T>,
}
307
// SAFETY (NOTE(review): confirm): `Model<T>` never stores a `T`. Its fields are
// ids, a `Weak` pointer to the lock-protected ref-count table and a
// `PhantomData<T>`, so these impls lift the `T: Send` / `T: Sync` requirement the
// marker would otherwise impose. Soundness presumably relies on the entity itself
// only being reachable through a context — verify against callers.
unsafe impl<T> Send for Model<T> {}
unsafe impl<T> Sync for Model<T> {}
// Implements the crate-private `Sealed` marker for `Model`.
impl<T> Sealed for Model<T> {}
311
312impl<T: 'static> Entity<T> for Model<T> {
313 type Weak = WeakModel<T>;
314
315 fn entity_id(&self) -> EntityId {
316 self.any_model.entity_id
317 }
318
319 fn downgrade(&self) -> Self::Weak {
320 WeakModel {
321 any_model: self.any_model.downgrade(),
322 entity_type: self.entity_type,
323 }
324 }
325
326 fn upgrade_from(weak: &Self::Weak) -> Option<Self>
327 where
328 Self: Sized,
329 {
330 Some(Model {
331 any_model: weak.any_model.upgrade()?,
332 entity_type: weak.entity_type,
333 })
334 }
335}
336
337impl<T: 'static> Model<T> {
338 fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
339 where
340 T: 'static,
341 {
342 Self {
343 any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
344 entity_type: PhantomData,
345 }
346 }
347
348 /// Downgrade the this to a weak model reference
349 pub fn downgrade(&self) -> WeakModel<T> {
350 // Delegate to the trait implementation to keep behavior in one place.
351 // This method was included to improve method resolution in the presence of
352 // the Model's deref
353 Entity::downgrade(self)
354 }
355
356 /// Convert this into a dynamically typed model.
357 pub fn into_any(self) -> AnyModel {
358 self.any_model
359 }
360
361 pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
362 cx.entities.read(self)
363 }
364
365 pub fn read_with<R, C: Context>(
366 &self,
367 cx: &C,
368 f: impl FnOnce(&T, &AppContext) -> R,
369 ) -> C::Result<R> {
370 cx.read_model(self, f)
371 }
372
373 /// Update the entity referenced by this model with the given function.
374 ///
375 /// The update function receives a context appropriate for its environment.
376 /// When updating in an `AppContext`, it receives a `ModelContext`.
377 /// When updating an a `WindowContext`, it receives a `ViewContext`.
378 pub fn update<C, R>(
379 &self,
380 cx: &mut C,
381 update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
382 ) -> C::Result<R>
383 where
384 C: Context,
385 {
386 cx.update_model(self, update)
387 }
388}
389
390impl<T> Clone for Model<T> {
391 fn clone(&self) -> Self {
392 Self {
393 any_model: self.any_model.clone(),
394 entity_type: self.entity_type,
395 }
396 }
397}
398
399impl<T> std::fmt::Debug for Model<T> {
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 write!(
402 f,
403 "Model {{ entity_id: {:?}, entity_type: {:?} }}",
404 self.any_model.entity_id,
405 type_name::<T>()
406 )
407 }
408}
409
410impl<T> Hash for Model<T> {
411 fn hash<H: Hasher>(&self, state: &mut H) {
412 self.any_model.hash(state);
413 }
414}
415
416impl<T> PartialEq for Model<T> {
417 fn eq(&self, other: &Self) -> bool {
418 self.any_model == other.any_model
419 }
420}
421
// Marker impl: `Model` equality delegates to the wrapped `AnyModel`.
impl<T> Eq for Model<T> {}
423
424impl<T> PartialEq<WeakModel<T>> for Model<T> {
425 fn eq(&self, other: &WeakModel<T>) -> bool {
426 self.any_model.entity_id() == other.entity_id()
427 }
428}
429
/// A dynamically typed weak handle to an entity. It does not contribute to the
/// entity's reference count.
#[derive(Clone)]
pub struct AnyWeakModel {
    /// Id of the referenced entity.
    pub(crate) entity_id: EntityId,
    /// `TypeId` of the entity's concrete type.
    entity_type: TypeId,
    /// Weak pointer to the reference-count table consulted by `upgrade`.
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
436
437impl AnyWeakModel {
438 pub fn entity_id(&self) -> EntityId {
439 self.entity_id
440 }
441
442 pub fn is_upgradable(&self) -> bool {
443 let ref_count = self
444 .entity_ref_counts
445 .upgrade()
446 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
447 .unwrap_or(0);
448 ref_count > 0
449 }
450
451 pub fn upgrade(&self) -> Option<AnyModel> {
452 let ref_counts = &self.entity_ref_counts.upgrade()?;
453 let ref_counts = ref_counts.read();
454 let ref_count = ref_counts.counts.get(self.entity_id)?;
455
456 // entity_id is in dropped_entity_ids
457 if ref_count.load(SeqCst) == 0 {
458 return None;
459 }
460 ref_count.fetch_add(1, SeqCst);
461 drop(ref_counts);
462
463 Some(AnyModel {
464 entity_id: self.entity_id,
465 entity_type: self.entity_type,
466 entity_map: self.entity_ref_counts.clone(),
467 #[cfg(any(test, feature = "test-support"))]
468 handle_id: self
469 .entity_ref_counts
470 .upgrade()
471 .unwrap()
472 .write()
473 .leak_detector
474 .handle_created(self.entity_id),
475 })
476 }
477
478 #[cfg(any(test, feature = "test-support"))]
479 pub fn assert_dropped(&self) {
480 self.entity_ref_counts
481 .upgrade()
482 .unwrap()
483 .write()
484 .leak_detector
485 .assert_dropped(self.entity_id);
486
487 if self
488 .entity_ref_counts
489 .upgrade()
490 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
491 .is_some()
492 {
493 panic!(
494 "entity was recently dropped but resources are retained until the end of the effect cycle."
495 )
496 }
497 }
498}
499
500impl<T> From<WeakModel<T>> for AnyWeakModel {
501 fn from(model: WeakModel<T>) -> Self {
502 model.any_model
503 }
504}
505
506impl Hash for AnyWeakModel {
507 fn hash<H: Hasher>(&self, state: &mut H) {
508 self.entity_id.hash(state);
509 }
510}
511
512impl PartialEq for AnyWeakModel {
513 fn eq(&self, other: &Self) -> bool {
514 self.entity_id == other.entity_id
515 }
516}
517
// Marker impl: `AnyWeakModel` equality compares entity ids only.
impl Eq for AnyWeakModel {}
519
/// A weak handle to an entity whose concrete type `T` is known statically.
/// Dereferences to the untyped `AnyWeakModel` it wraps.
#[derive(Deref, DerefMut)]
pub struct WeakModel<T> {
    /// The untyped weak handle that carries the id.
    #[deref]
    #[deref_mut]
    any_model: AnyWeakModel,
    /// Marker for the entity's type; no `T` is stored in the handle.
    entity_type: PhantomData<T>,
}
527
// SAFETY (NOTE(review): confirm): `WeakModel<T>` never stores a `T`. Its fields
// are ids, a `Weak` pointer to the lock-protected ref-count table and a
// `PhantomData<T>`, so these impls lift the `T: Send` / `T: Sync` requirement the
// marker would otherwise impose.
unsafe impl<T> Send for WeakModel<T> {}
unsafe impl<T> Sync for WeakModel<T> {}
530
531impl<T> Clone for WeakModel<T> {
532 fn clone(&self) -> Self {
533 Self {
534 any_model: self.any_model.clone(),
535 entity_type: self.entity_type,
536 }
537 }
538}
539
540impl<T: 'static> WeakModel<T> {
541 /// Upgrade this weak model reference into a strong model reference
542 pub fn upgrade(&self) -> Option<Model<T>> {
543 // Delegate to the trait implementation to keep behavior in one place.
544 Model::upgrade_from(self)
545 }
546
547 /// Update the entity referenced by this model with the given function if
548 /// the referenced entity still exists. Returns an error if the entity has
549 /// been released.
550 pub fn update<C, R>(
551 &self,
552 cx: &mut C,
553 update: impl FnOnce(&mut T, &mut ModelContext<'_, T>) -> R,
554 ) -> Result<R>
555 where
556 C: Context,
557 Result<C::Result<R>>: crate::Flatten<R>,
558 {
559 crate::Flatten::flatten(
560 self.upgrade()
561 .ok_or_else(|| anyhow!("entity release"))
562 .map(|this| cx.update_model(&this, update)),
563 )
564 }
565
566 /// Reads the entity referenced by this model with the given function if
567 /// the referenced entity still exists. Returns an error if the entity has
568 /// been released.
569 pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &AppContext) -> R) -> Result<R>
570 where
571 C: Context,
572 Result<C::Result<R>>: crate::Flatten<R>,
573 {
574 crate::Flatten::flatten(
575 self.upgrade()
576 .ok_or_else(|| anyhow!("entity release"))
577 .map(|this| cx.read_model(&this, read)),
578 )
579 }
580}
581
582impl<T> Hash for WeakModel<T> {
583 fn hash<H: Hasher>(&self, state: &mut H) {
584 self.any_model.hash(state);
585 }
586}
587
588impl<T> PartialEq for WeakModel<T> {
589 fn eq(&self, other: &Self) -> bool {
590 self.any_model == other.any_model
591 }
592}
593
// Marker impl: `WeakModel` equality delegates to the wrapped `AnyWeakModel`.
impl<T> Eq for WeakModel<T> {}
595
596impl<T> PartialEq<Model<T>> for WeakModel<T> {
597 fn eq(&self, other: &Model<T>) -> bool {
598 self.entity_id() == other.any_model.entity_id()
599 }
600}
601
// `true` when the `LEAK_BACKTRACE` environment variable is set to a non-empty
// value. `LeakDetector::handle_created` captures a backtrace only when it is.
#[cfg(any(test, feature = "test-support"))]
lazy_static! {
    static ref LEAK_BACKTRACE: bool =
        std::env::var("LEAK_BACKTRACE").map_or(false, |b| !b.is_empty());
}
607
/// Identifies one individual strong handle for the leak detector.
#[cfg(any(test, feature = "test-support"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct HandleId {
    id: u64, // id of the handle itself, not the pointed at object
}
613
/// Records which strong handles currently exist for each entity so that
/// leaked handles can be reported.
#[cfg(any(test, feature = "test-support"))]
pub struct LeakDetector {
    /// Value handed out as the id of the next registered handle.
    next_handle_id: u64,
    /// Live handles per entity, each with the backtrace captured at creation
    /// (if backtrace capture was enabled).
    entity_handles: HashMap<EntityId, HashMap<HandleId, Option<backtrace::Backtrace>>>,
}
619
#[cfg(any(test, feature = "test-support"))]
impl LeakDetector {
    /// Registers a new handle for `entity_id` and returns its id. An unresolved
    /// backtrace is stored with it when `LEAK_BACKTRACE` is enabled.
    #[track_caller]
    pub fn handle_created(&mut self, entity_id: EntityId) -> HandleId {
        let handle_id = HandleId {
            id: util::post_inc(&mut self.next_handle_id),
        };
        let trace = LEAK_BACKTRACE.then(|| backtrace::Backtrace::new_unresolved());
        self.entity_handles
            .entry(entity_id)
            .or_default()
            .insert(handle_id, trace);
        handle_id
    }

    /// Forgets the handle `handle_id` of `entity_id`.
    pub fn handle_dropped(&mut self, entity_id: EntityId, handle_id: HandleId) {
        self.entity_handles
            .entry(entity_id)
            .or_default()
            .remove(&handle_id);
    }

    /// Panics if any handle of `entity_id` is still registered, after printing
    /// one line (or a resolved backtrace) per leaked handle to stderr.
    pub fn assert_dropped(&mut self, entity_id: EntityId) {
        let handles = self.entity_handles.entry(entity_id).or_default();
        if handles.is_empty() {
            return;
        }

        for backtrace in handles.values_mut() {
            match backtrace.take() {
                Some(mut backtrace) => {
                    backtrace.resolve();
                    eprintln!("Leaked handle: {:#?}", backtrace);
                }
                None => {
                    eprintln!("Leaked handle: export LEAK_BACKTRACE to find allocation site");
                }
            }
        }
        panic!();
    }
}
654
#[cfg(test)]
mod test {
    use crate::EntityMap;

    struct TestEntity {
        pub i: i32,
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Tests that slots are not re-used before take_dropped.
        let mut entity_map = EntityMap::new();

        // The handle returned by `insert` is discarded right away, so both
        // entities are already unreferenced when the cleanup runs.
        for i in [1, 2] {
            let slot = entity_map.reserve::<TestEntity>();
            entity_map.insert(slot, TestEntity { i });
        }

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // Tests that weak handles are not upgraded before take_dropped
        let mut entity_map = EntityMap::new();

        let weak = {
            let slot = entity_map.reserve::<TestEntity>();
            let handle = entity_map.insert(slot, TestEntity { i: 1 });
            // Only the weak handle leaves this scope; the strong one is dropped here.
            handle.downgrade()
        };

        let strong = weak.upgrade();
        assert_eq!(strong, None);

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1]);
    }
}