1use crate::{AnyBox, AppContext, Context};
2use anyhow::{anyhow, Result};
3use derive_more::{Deref, DerefMut};
4use parking_lot::{RwLock, RwLockUpgradableReadGuard};
5use slotmap::{SecondaryMap, SlotMap};
6use std::{
7 any::{type_name, Any, TypeId},
8 fmt::{self, Display},
9 hash::{Hash, Hasher},
10 marker::PhantomData,
11 mem,
12 sync::{
13 atomic::{AtomicUsize, Ordering::SeqCst},
14 Arc, Weak,
15 },
16};
17
slotmap::new_key_type! {
    /// Identifies an entity stored in an `EntityMap`.
    ///
    /// Ids are allocated by inserting into the `SlotMap` of reference counts in
    /// `EntityRefCounts`; the same id then keys the `SecondaryMap` of entity values.
    pub struct EntityId;
}
19
20impl EntityId {
21 pub fn as_u64(self) -> u64 {
22 self.0.as_ffi()
23 }
24}
25
26impl Display for EntityId {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "{}", self)
29 }
30}
31
/// Storage for entity values, keyed by `EntityId`.
///
/// Values live in `entities`. Reference counts, plus the queue of ids whose count
/// reached zero, live in the shared `ref_counts`. Handles created by this map hold
/// only a `Weak` pointer to `ref_counts`.
pub(crate) struct EntityMap {
    /// Entity values. An id has no entry here while its entity is leased
    /// (`lease` removes it, `end_lease` puts it back), or if its slot was
    /// reserved but never filled with `insert`.
    entities: SecondaryMap<EntityId, AnyBox>,
    /// Reference-count table shared (weakly) with every handle from this map.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
36
/// Reference-count bookkeeping shared between an `EntityMap` and its handles.
struct EntityRefCounts {
    /// Strong-handle count per entity. Inserting here is also what allocates
    /// new `EntityId`s (see `EntityMap::reserve`).
    counts: SlotMap<EntityId, AtomicUsize>,
    /// Ids whose count fell to zero. `AnyHandle::drop` pushes them and
    /// `EntityMap::take_dropped` drains them.
    dropped_entity_ids: Vec<EntityId>,
}
41
42impl EntityMap {
43 pub fn new() -> Self {
44 Self {
45 entities: SecondaryMap::new(),
46 ref_counts: Arc::new(RwLock::new(EntityRefCounts {
47 counts: SlotMap::with_key(),
48 dropped_entity_ids: Vec::new(),
49 })),
50 }
51 }
52
53 /// Reserve a slot for an entity, which you can subsequently use with `insert`.
54 pub fn reserve<T: 'static>(&self) -> Slot<T> {
55 let id = self.ref_counts.write().counts.insert(1.into());
56 Slot(Handle::new(id, Arc::downgrade(&self.ref_counts)))
57 }
58
59 /// Insert an entity into a slot obtained by calling `reserve`.
60 pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Handle<T>
61 where
62 T: Any + Send + Sync,
63 {
64 let handle = slot.0;
65 self.entities.insert(handle.entity_id, Box::new(entity));
66 handle
67 }
68
69 /// Move an entity to the stack.
70 pub fn lease<'a, T>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> {
71 self.assert_valid_context(handle);
72 let entity = Some(
73 self.entities
74 .remove(handle.entity_id)
75 .expect("Circular entity lease. Is the entity already being updated?"),
76 );
77 Lease {
78 handle,
79 entity,
80 entity_type: PhantomData,
81 }
82 }
83
84 /// Return an entity after moving it to the stack.
85 pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
86 self.entities
87 .insert(lease.handle.entity_id, lease.entity.take().unwrap());
88 }
89
90 pub fn read<T: 'static>(&self, handle: &Handle<T>) -> &T {
91 self.assert_valid_context(handle);
92 self.entities[handle.entity_id].downcast_ref().unwrap()
93 }
94
95 fn assert_valid_context(&self, handle: &AnyHandle) {
96 debug_assert!(
97 Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)),
98 "used a handle with the wrong context"
99 );
100 }
101
102 pub fn take_dropped(&mut self) -> Vec<(EntityId, AnyBox)> {
103 let mut ref_counts = self.ref_counts.write();
104 let dropped_entity_ids = mem::take(&mut ref_counts.dropped_entity_ids);
105
106 dropped_entity_ids
107 .into_iter()
108 .map(|entity_id| {
109 let count = ref_counts.counts.remove(entity_id).unwrap();
110 debug_assert_eq!(
111 count.load(SeqCst),
112 0,
113 "dropped an entity that was referenced"
114 );
115 (entity_id, self.entities.remove(entity_id).unwrap())
116 })
117 .collect()
118 }
119}
120
/// An entity temporarily moved out of an `EntityMap` by `EntityMap::lease`.
///
/// Derefs to the entity as `T`. It must be handed back with `EntityMap::end_lease`.
/// A lease dropped while still holding its entity only logs an error. The value is
/// then dropped together with the lease rather than returned to the map.
pub struct Lease<'a, T> {
    /// The leased value; `None` once `end_lease` has taken it back.
    entity: Option<AnyBox>,
    /// Handle of the leased entity.
    pub handle: &'a Handle<T>,
    /// Ties the lease to `T` for the typed `Deref`/`DerefMut` impls.
    entity_type: PhantomData<T>,
}
126
127impl<'a, T: 'static> core::ops::Deref for Lease<'a, T> {
128 type Target = T;
129
130 fn deref(&self) -> &Self::Target {
131 self.entity.as_ref().unwrap().downcast_ref().unwrap()
132 }
133}
134
135impl<'a, T: 'static> core::ops::DerefMut for Lease<'a, T> {
136 fn deref_mut(&mut self) -> &mut Self::Target {
137 self.entity.as_mut().unwrap().downcast_mut().unwrap()
138 }
139}
140
141impl<'a, T> Drop for Lease<'a, T> {
142 fn drop(&mut self) {
143 if self.entity.is_some() {
144 // We don't panic here, because other panics can cause us to drop the lease without ending it cleanly.
145 log::error!("Leases must be ended with EntityMap::end_lease")
146 }
147 }
148}
149
/// A reserved entity id (from `EntityMap::reserve`) whose value has not been
/// stored yet. Pass it to `EntityMap::insert` to store the entity and get its
/// `Handle<T>`. Derefs to that underlying handle.
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Handle<T>);
152
/// A strong, type-erased reference to an entity.
///
/// Cloning increments, and dropping decrements, the entity's count in the owning
/// map's `EntityRefCounts`. Equality and hashing consider only `entity_id`.
pub struct AnyHandle {
    /// Id of the referenced entity.
    pub(crate) entity_id: EntityId,
    /// `TypeId` of the entity's concrete type, checked by `downcast`.
    entity_type: TypeId,
    /// Weak pointer to the owning map's reference-count table. If it can no longer
    /// be upgraded, `clone` and `drop` skip reference counting entirely.
    entity_map: Weak<RwLock<EntityRefCounts>>,
}
158
159impl AnyHandle {
160 fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
161 Self {
162 entity_id: id,
163 entity_type,
164 entity_map,
165 }
166 }
167
168 pub fn entity_id(&self) -> EntityId {
169 self.entity_id
170 }
171
172 pub fn downgrade(&self) -> AnyWeakHandle {
173 AnyWeakHandle {
174 entity_id: self.entity_id,
175 entity_type: self.entity_type,
176 entity_ref_counts: self.entity_map.clone(),
177 }
178 }
179
180 pub fn downcast<T: 'static>(&self) -> Option<Handle<T>> {
181 if TypeId::of::<T>() == self.entity_type {
182 Some(Handle {
183 any_handle: self.clone(),
184 entity_type: PhantomData,
185 })
186 } else {
187 None
188 }
189 }
190}
191
192impl Clone for AnyHandle {
193 fn clone(&self) -> Self {
194 if let Some(entity_map) = self.entity_map.upgrade() {
195 let entity_map = entity_map.read();
196 let count = entity_map
197 .counts
198 .get(self.entity_id)
199 .expect("detected over-release of a handle");
200 let prev_count = count.fetch_add(1, SeqCst);
201 assert_ne!(prev_count, 0, "Detected over-release of a handle.");
202 }
203
204 Self {
205 entity_id: self.entity_id,
206 entity_type: self.entity_type,
207 entity_map: self.entity_map.clone(),
208 }
209 }
210}
211
impl Drop for AnyHandle {
    /// Releases this strong reference.
    ///
    /// While the owning map's reference-count table is alive, this decrements the
    /// entity's count. The handle that sees a previous count of 1 (the last reference)
    /// pushes the id onto `dropped_entity_ids` for `EntityMap::take_dropped` to
    /// collect. If the table is gone, nothing is recorded.
    ///
    /// # Panics
    ///
    /// Panics if the entity has no count entry or its count was already zero, both of
    /// which indicate an over-released handle.
    fn drop(&mut self) {
        if let Some(entity_map) = self.entity_map.upgrade() {
            // Take an upgradable read so the lock only needs upgrading to a write in
            // the last-reference case, where the id must be queued.
            let entity_map = entity_map.upgradable_read();
            let count = entity_map
                .counts
                .get(self.entity_id)
                .expect("detected over-release of a handle.");
            let prev_count = count.fetch_sub(1, SeqCst);
            assert_ne!(prev_count, 0, "Detected over-release of a handle.");
            if prev_count == 1 {
                // We were the last reference to this entity, so we can remove it.
                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
                entity_map.dropped_entity_ids.push(self.entity_id);
            }
        }
    }
}
230
231impl<T> From<Handle<T>> for AnyHandle {
232 fn from(handle: Handle<T>) -> Self {
233 handle.any_handle
234 }
235}
236
237impl Hash for AnyHandle {
238 fn hash<H: Hasher>(&self, state: &mut H) {
239 self.entity_id.hash(state);
240 }
241}
242
243impl PartialEq for AnyHandle {
244 fn eq(&self, other: &Self) -> bool {
245 self.entity_id == other.entity_id
246 }
247}
248
249impl Eq for AnyHandle {}
250
/// A strong, typed reference to an entity of type `T`.
///
/// Wraps an `AnyHandle` and derefs to it. `T` is tracked only through `PhantomData`.
#[derive(Deref, DerefMut)]
pub struct Handle<T> {
    /// The type-erased reference that does the reference counting.
    #[deref]
    #[deref_mut]
    any_handle: AnyHandle,
    /// Marks the entity type; no `T` is stored.
    entity_type: PhantomData<T>,
}

// SAFETY (review): a `Handle<T>` never stores a `T`. It holds only an `AnyHandle`
// (an id, a `TypeId` and a `Weak` to the ref-count table) plus `PhantomData<T>`, so
// these impls drop the auto-trait dependence on `T`. The entity value itself is
// reached through an `EntityMap`/`AppContext`. Presumably that is what controls
// cross-thread access to `T`; TODO confirm.
unsafe impl<T> Send for Handle<T> {}
unsafe impl<T> Sync for Handle<T> {}
261
262impl<T: 'static> Handle<T> {
263 fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
264 where
265 T: 'static,
266 {
267 Self {
268 any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
269 entity_type: PhantomData,
270 }
271 }
272
273 pub fn downgrade(&self) -> WeakHandle<T> {
274 WeakHandle {
275 any_handle: self.any_handle.downgrade(),
276 entity_type: self.entity_type,
277 }
278 }
279
280 pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
281 cx.entities.read(self)
282 }
283
284 /// Update the entity referenced by this handle with the given function.
285 ///
286 /// The update function receives a context appropriate for its environment.
287 /// When updating in an `AppContext`, it receives a `ModelContext`.
288 /// When updating an a `WindowContext`, it receives a `ViewContext`.
289 pub fn update<C, R>(
290 &self,
291 cx: &mut C,
292 update: impl FnOnce(&mut T, &mut C::EntityContext<'_, '_, T>) -> R,
293 ) -> C::Result<R>
294 where
295 C: Context,
296 {
297 cx.update_entity(self, update)
298 }
299}
300
301impl<T> Clone for Handle<T> {
302 fn clone(&self) -> Self {
303 Self {
304 any_handle: self.any_handle.clone(),
305 entity_type: self.entity_type,
306 }
307 }
308}
309
310impl<T> std::fmt::Debug for Handle<T> {
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 write!(
313 f,
314 "Handle {{ entity_id: {:?}, entity_type: {:?} }}",
315 self.any_handle.entity_id,
316 type_name::<T>()
317 )
318 }
319}
320
321impl<T> Hash for Handle<T> {
322 fn hash<H: Hasher>(&self, state: &mut H) {
323 self.any_handle.hash(state);
324 }
325}
326
327impl<T> PartialEq for Handle<T> {
328 fn eq(&self, other: &Self) -> bool {
329 self.any_handle == other.any_handle
330 }
331}
332
333impl<T> Eq for Handle<T> {}
334
335impl<T> PartialEq<WeakHandle<T>> for Handle<T> {
336 fn eq(&self, other: &WeakHandle<T>) -> bool {
337 self.entity_id() == other.entity_id()
338 }
339}
340
/// A weak, type-erased reference to an entity.
///
/// Creating, cloning and dropping it do not touch the entity's reference count.
/// Call `upgrade` to obtain an `AnyHandle` while the entity still has strong
/// references.
#[derive(Clone)]
pub struct AnyWeakHandle {
    /// Id of the referenced entity.
    pub(crate) entity_id: EntityId,
    /// `TypeId` of the entity's concrete type, carried over to upgraded handles.
    entity_type: TypeId,
    /// Weak pointer to the owning map's reference-count table.
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
347
348impl AnyWeakHandle {
349 pub fn entity_id(&self) -> EntityId {
350 self.entity_id
351 }
352
353 pub fn is_upgradable(&self) -> bool {
354 let ref_count = self
355 .entity_ref_counts
356 .upgrade()
357 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
358 .unwrap_or(0);
359 ref_count > 0
360 }
361
362 pub fn upgrade(&self) -> Option<AnyHandle> {
363 let entity_map = self.entity_ref_counts.upgrade()?;
364 let entity_map = entity_map.read();
365 let ref_count = entity_map.counts.get(self.entity_id)?;
366
367 // entity_id is in dropped_entity_ids
368 if ref_count.load(SeqCst) == 0 {
369 return None;
370 }
371 ref_count.fetch_add(1, SeqCst);
372
373 Some(AnyHandle {
374 entity_id: self.entity_id,
375 entity_type: self.entity_type,
376 entity_map: self.entity_ref_counts.clone(),
377 })
378 }
379}
380
381impl<T> From<WeakHandle<T>> for AnyWeakHandle {
382 fn from(handle: WeakHandle<T>) -> Self {
383 handle.any_handle
384 }
385}
386
387impl Hash for AnyWeakHandle {
388 fn hash<H: Hasher>(&self, state: &mut H) {
389 self.entity_id.hash(state);
390 }
391}
392
393impl PartialEq for AnyWeakHandle {
394 fn eq(&self, other: &Self) -> bool {
395 self.entity_id == other.entity_id
396 }
397}
398
399impl Eq for AnyWeakHandle {}
400
/// A weak, typed reference to an entity of type `T`.
///
/// Wraps an `AnyWeakHandle` and derefs to it. `T` is tracked only through
/// `PhantomData`.
#[derive(Deref, DerefMut)]
pub struct WeakHandle<T> {
    /// The type-erased weak reference.
    #[deref]
    #[deref_mut]
    any_handle: AnyWeakHandle,
    /// Marks the entity type; no `T` is stored.
    entity_type: PhantomData<T>,
}

// SAFETY (review): a `WeakHandle<T>` never stores a `T`. It holds only an
// `AnyWeakHandle` (an id, a `TypeId` and a `Weak` to the ref-count table) plus
// `PhantomData<T>`, so these impls drop the auto-trait dependence on `T`. Access
// to the entity value goes through an upgraded handle and the context. Presumably
// that is what controls cross-thread access to `T`; TODO confirm.
unsafe impl<T> Send for WeakHandle<T> {}
unsafe impl<T> Sync for WeakHandle<T> {}
411
412impl<T> Clone for WeakHandle<T> {
413 fn clone(&self) -> Self {
414 Self {
415 any_handle: self.any_handle.clone(),
416 entity_type: self.entity_type,
417 }
418 }
419}
420
421impl<T: 'static> WeakHandle<T> {
422 pub fn upgrade(&self) -> Option<Handle<T>> {
423 Some(Handle {
424 any_handle: self.any_handle.upgrade()?,
425 entity_type: self.entity_type,
426 })
427 }
428
429 /// Update the entity referenced by this handle with the given function if
430 /// the referenced entity still exists. Returns an error if the entity has
431 /// been released.
432 ///
433 /// The update function receives a context appropriate for its environment.
434 /// When updating in an `AppContext`, it receives a `ModelContext`.
435 /// When updating an a `WindowContext`, it receives a `ViewContext`.
436 pub fn update<C, R>(
437 &self,
438 cx: &mut C,
439 update: impl FnOnce(&mut T, &mut C::EntityContext<'_, '_, T>) -> R,
440 ) -> Result<R>
441 where
442 C: Context,
443 Result<C::Result<R>>: crate::Flatten<R>,
444 {
445 crate::Flatten::flatten(
446 self.upgrade()
447 .ok_or_else(|| anyhow!("entity release"))
448 .map(|this| cx.update_entity(&this, update)),
449 )
450 }
451}
452
453impl<T> Hash for WeakHandle<T> {
454 fn hash<H: Hasher>(&self, state: &mut H) {
455 self.any_handle.hash(state);
456 }
457}
458
459impl<T> PartialEq for WeakHandle<T> {
460 fn eq(&self, other: &Self) -> bool {
461 self.any_handle == other.any_handle
462 }
463}
464
465impl<T> Eq for WeakHandle<T> {}
466
467impl<T> PartialEq<Handle<T>> for WeakHandle<T> {
468 fn eq(&self, other: &Handle<T>) -> bool {
469 self.entity_id() == other.entity_id()
470 }
471}
472
#[cfg(test)]
mod test {
    use crate::EntityMap;

    struct TestEntity {
        pub i: i32,
    }

    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        // Two entities whose handles are released immediately must both still be
        // reported by one `take_dropped`, in release order. This shows their slots
        // were not reused before cleanup.
        let mut entity_map = EntityMap::new();

        for i in [1, 2] {
            let slot = entity_map.reserve::<TestEntity>();
            // The returned handle is a temporary, dropped at the end of this statement.
            entity_map.insert(slot, TestEntity { i });
        }

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1, 2]);
    }

    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        // Once the last strong handle is gone, a weak handle must not upgrade,
        // even though `take_dropped` has not removed the entity yet.
        let mut entity_map = EntityMap::new();

        let weak = {
            let slot = entity_map.reserve::<TestEntity>();
            let handle = entity_map.insert(slot, TestEntity { i: 1 });
            handle.downgrade()
            // `handle` is dropped here, at the end of the block.
        };

        assert_eq!(weak.upgrade(), None);

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);

        let values: Vec<i32> = dropped
            .into_iter()
            .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
            .collect();
        assert_eq!(values, vec![1]);
    }
}