entity_map.rs

  1use crate::{AppContext, Context};
  2use anyhow::{anyhow, Result};
  3use derive_more::{Deref, DerefMut};
  4use parking_lot::{RwLock, RwLockUpgradableReadGuard};
  5use slotmap::{SecondaryMap, SlotMap};
  6use std::{
  7    any::{type_name, Any, TypeId},
  8    fmt::{self, Display},
  9    hash::{Hash, Hasher},
 10    marker::PhantomData,
 11    mem,
 12    sync::{
 13        atomic::{AtomicUsize, Ordering::SeqCst},
 14        Arc, Weak,
 15    },
 16};
 17
 18slotmap::new_key_type! { pub struct EntityId; }
 19
 20impl EntityId {
 21    pub fn as_u64(self) -> u64 {
 22        self.0.as_ffi()
 23    }
 24}
 25
 26impl Display for EntityId {
 27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 28        write!(f, "{}", self)
 29    }
 30}
 31
 32pub(crate) struct EntityMap {
 33    entities: SecondaryMap<EntityId, Box<dyn Any + Send + Sync>>,
 34    ref_counts: Arc<RwLock<EntityRefCounts>>,
 35}
 36
 37struct EntityRefCounts {
 38    counts: SlotMap<EntityId, AtomicUsize>,
 39    dropped_entity_ids: Vec<EntityId>,
 40}
 41
 42impl EntityMap {
 43    pub fn new() -> Self {
 44        Self {
 45            entities: SecondaryMap::new(),
 46            ref_counts: Arc::new(RwLock::new(EntityRefCounts {
 47                counts: SlotMap::with_key(),
 48                dropped_entity_ids: Vec::new(),
 49            })),
 50        }
 51    }
 52
 53    /// Reserve a slot for an entity, which you can subsequently use with `insert`.
 54    pub fn reserve<T: 'static + Send + Sync>(&self) -> Slot<T> {
 55        let id = self.ref_counts.write().counts.insert(1.into());
 56        Slot(Handle::new(id, Arc::downgrade(&self.ref_counts)))
 57    }
 58
 59    /// Insert an entity into a slot obtained by calling `reserve`.
 60    pub fn insert<T: 'static + Any + Send + Sync>(
 61        &mut self,
 62        slot: Slot<T>,
 63        entity: T,
 64    ) -> Handle<T> {
 65        let handle = slot.0;
 66        self.entities.insert(handle.entity_id, Box::new(entity));
 67        handle
 68    }
 69
 70    /// Move an entity to the stack.
 71    pub fn lease<'a, T: 'static + Send + Sync>(&mut self, handle: &'a Handle<T>) -> Lease<'a, T> {
 72        let entity = Some(
 73            self.entities
 74                .remove(handle.entity_id)
 75                .expect("Circular entity lease. Is the entity already being updated?")
 76                .downcast::<T>()
 77                .unwrap(),
 78        );
 79        Lease { handle, entity }
 80    }
 81
 82    /// Return an entity after moving it to the stack.
 83    pub fn end_lease<T: 'static + Send + Sync>(&mut self, mut lease: Lease<T>) {
 84        self.entities
 85            .insert(lease.handle.entity_id, lease.entity.take().unwrap());
 86    }
 87
 88    pub fn read<T: 'static + Send + Sync>(&self, handle: &Handle<T>) -> &T {
 89        self.entities[handle.entity_id].downcast_ref().unwrap()
 90    }
 91
 92    pub fn weak_handle<T: 'static + Send + Sync>(&self, id: EntityId) -> WeakHandle<T> {
 93        WeakHandle {
 94            any_handle: AnyWeakHandle {
 95                entity_id: id,
 96                entity_type: TypeId::of::<T>(),
 97                entity_ref_counts: Arc::downgrade(&self.ref_counts),
 98            },
 99            entity_type: PhantomData,
100        }
101    }
102
103    pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any + Send + Sync>)> {
104        let dropped_entity_ids = mem::take(&mut self.ref_counts.write().dropped_entity_ids);
105        dropped_entity_ids
106            .into_iter()
107            .map(|entity_id| (entity_id, self.entities.remove(entity_id).unwrap()))
108            .collect()
109    }
110}
111
112pub struct Lease<'a, T: Send + Sync> {
113    entity: Option<Box<T>>,
114    pub handle: &'a Handle<T>,
115}
116
117impl<'a, T> core::ops::Deref for Lease<'a, T>
118where
119    T: Send + Sync,
120{
121    type Target = T;
122
123    fn deref(&self) -> &Self::Target {
124        self.entity.as_ref().unwrap()
125    }
126}
127
128impl<'a, T> core::ops::DerefMut for Lease<'a, T>
129where
130    T: Send + Sync,
131{
132    fn deref_mut(&mut self) -> &mut Self::Target {
133        self.entity.as_mut().unwrap()
134    }
135}
136
137impl<'a, T> Drop for Lease<'a, T>
138where
139    T: Send + Sync,
140{
141    fn drop(&mut self) {
142        if self.entity.is_some() {
143            // We don't panic here, because other panics can cause us to drop the lease without ending it cleanly.
144            log::error!("Leases must be ended with EntityMap::end_lease")
145        }
146    }
147}
148
149#[derive(Deref, DerefMut)]
150pub struct Slot<T: Send + Sync + 'static>(Handle<T>);
151
152pub struct AnyHandle {
153    pub(crate) entity_id: EntityId,
154    entity_type: TypeId,
155    entity_map: Weak<RwLock<EntityRefCounts>>,
156}
157
158impl AnyHandle {
159    fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
160        Self {
161            entity_id: id,
162            entity_type,
163            entity_map,
164        }
165    }
166
167    pub fn entity_id(&self) -> EntityId {
168        self.entity_id
169    }
170
171    pub fn downgrade(&self) -> AnyWeakHandle {
172        AnyWeakHandle {
173            entity_id: self.entity_id,
174            entity_type: self.entity_type,
175            entity_ref_counts: self.entity_map.clone(),
176        }
177    }
178
179    pub fn downcast<T>(&self) -> Option<Handle<T>>
180    where
181        T: 'static + Send + Sync,
182    {
183        if TypeId::of::<T>() == self.entity_type {
184            Some(Handle {
185                any_handle: self.clone(),
186                entity_type: PhantomData,
187            })
188        } else {
189            None
190        }
191    }
192}
193
194impl Clone for AnyHandle {
195    fn clone(&self) -> Self {
196        if let Some(entity_map) = self.entity_map.upgrade() {
197            let entity_map = entity_map.read();
198            let count = entity_map
199                .counts
200                .get(self.entity_id)
201                .expect("detected over-release of a handle");
202            let prev_count = count.fetch_add(1, SeqCst);
203            assert_ne!(prev_count, 0, "Detected over-release of a handle.");
204        }
205
206        Self {
207            entity_id: self.entity_id,
208            entity_type: self.entity_type,
209            entity_map: self.entity_map.clone(),
210        }
211    }
212}
213
214impl Drop for AnyHandle {
215    fn drop(&mut self) {
216        if let Some(entity_map) = self.entity_map.upgrade() {
217            let entity_map = entity_map.upgradable_read();
218            let count = entity_map
219                .counts
220                .get(self.entity_id)
221                .expect("Detected over-release of a handle.");
222            let prev_count = count.fetch_sub(1, SeqCst);
223            assert_ne!(prev_count, 0, "Detected over-release of a handle.");
224            if prev_count == 1 {
225                // We were the last reference to this entity, so we can remove it.
226                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
227                entity_map.counts.remove(self.entity_id);
228                entity_map.dropped_entity_ids.push(self.entity_id);
229            }
230        }
231    }
232}
233
234impl<T> From<Handle<T>> for AnyHandle
235where
236    T: 'static + Send + Sync,
237{
238    fn from(handle: Handle<T>) -> Self {
239        handle.any_handle
240    }
241}
242
243impl Hash for AnyHandle {
244    fn hash<H: Hasher>(&self, state: &mut H) {
245        self.entity_id.hash(state);
246    }
247}
248
249impl PartialEq for AnyHandle {
250    fn eq(&self, other: &Self) -> bool {
251        self.entity_id == other.entity_id
252    }
253}
254
255impl Eq for AnyHandle {}
256
257#[derive(Deref, DerefMut)]
258pub struct Handle<T: Send + Sync> {
259    #[deref]
260    #[deref_mut]
261    any_handle: AnyHandle,
262    entity_type: PhantomData<T>,
263}
264
265impl<T: 'static + Send + Sync> Handle<T> {
266    fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
267        Self {
268            any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
269            entity_type: PhantomData,
270        }
271    }
272
273    pub fn downgrade(&self) -> WeakHandle<T> {
274        WeakHandle {
275            any_handle: self.any_handle.downgrade(),
276            entity_type: self.entity_type,
277        }
278    }
279
280    pub fn read<'a>(&self, cx: &'a AppContext) -> &'a T {
281        cx.entities.read(self)
282    }
283
284    /// Update the entity referenced by this handle with the given function.
285    ///
286    /// The update function receives a context appropriate for its environment.
287    /// When updating in an `AppContext`, it receives a `ModelContext`.
288    /// When updating an a `WindowContext`, it receives a `ViewContext`.
289    pub fn update<C: Context, R>(
290        &self,
291        cx: &mut C,
292        update: impl FnOnce(&mut T, &mut C::EntityContext<'_, '_, T>) -> R,
293    ) -> C::Result<R> {
294        cx.update_entity(self, update)
295    }
296}
297
298impl<T: Send + Sync> Clone for Handle<T> {
299    fn clone(&self) -> Self {
300        Self {
301            any_handle: self.any_handle.clone(),
302            entity_type: self.entity_type,
303        }
304    }
305}
306
307impl<T: 'static + Send + Sync> std::fmt::Debug for Handle<T> {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        write!(
310            f,
311            "Handle {{ entity_id: {:?}, entity_type: {:?} }}",
312            self.any_handle.entity_id,
313            type_name::<T>()
314        )
315    }
316}
317
318impl<T: Send + Sync + 'static> Hash for Handle<T> {
319    fn hash<H: Hasher>(&self, state: &mut H) {
320        self.any_handle.hash(state);
321    }
322}
323
324impl<T: Send + Sync + 'static> PartialEq for Handle<T> {
325    fn eq(&self, other: &Self) -> bool {
326        self.any_handle == other.any_handle
327    }
328}
329
330impl<T: Send + Sync + 'static> Eq for Handle<T> {}
331
332#[derive(Clone)]
333pub struct AnyWeakHandle {
334    pub(crate) entity_id: EntityId,
335    entity_type: TypeId,
336    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
337}
338
339impl AnyWeakHandle {
340    pub fn entity_id(&self) -> EntityId {
341        self.entity_id
342    }
343
344    pub fn is_upgradable(&self) -> bool {
345        let ref_count = self
346            .entity_ref_counts
347            .upgrade()
348            .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
349            .unwrap_or(0);
350        ref_count > 0
351    }
352
353    pub fn upgrade(&self) -> Option<AnyHandle> {
354        let entity_map = self.entity_ref_counts.upgrade()?;
355        entity_map
356            .read()
357            .counts
358            .get(self.entity_id)?
359            .fetch_add(1, SeqCst);
360        Some(AnyHandle {
361            entity_id: self.entity_id,
362            entity_type: self.entity_type,
363            entity_map: self.entity_ref_counts.clone(),
364        })
365    }
366}
367
368impl<T> From<WeakHandle<T>> for AnyWeakHandle
369where
370    T: 'static + Send + Sync,
371{
372    fn from(handle: WeakHandle<T>) -> Self {
373        handle.any_handle
374    }
375}
376
377impl Hash for AnyWeakHandle {
378    fn hash<H: Hasher>(&self, state: &mut H) {
379        self.entity_id.hash(state);
380    }
381}
382
383impl PartialEq for AnyWeakHandle {
384    fn eq(&self, other: &Self) -> bool {
385        self.entity_id == other.entity_id
386    }
387}
388
389impl Eq for AnyWeakHandle {}
390
391#[derive(Deref, DerefMut)]
392pub struct WeakHandle<T> {
393    #[deref]
394    #[deref_mut]
395    any_handle: AnyWeakHandle,
396    entity_type: PhantomData<T>,
397}
398
399impl<T: 'static + Send + Sync> Clone for WeakHandle<T> {
400    fn clone(&self) -> Self {
401        Self {
402            any_handle: self.any_handle.clone(),
403            entity_type: self.entity_type,
404        }
405    }
406}
407
408impl<T: Send + Sync + 'static> WeakHandle<T> {
409    pub fn upgrade(&self) -> Option<Handle<T>> {
410        Some(Handle {
411            any_handle: self.any_handle.upgrade()?,
412            entity_type: self.entity_type,
413        })
414    }
415
416    /// Update the entity referenced by this handle with the given function if
417    /// the referenced entity still exists. Returns an error if the entity has
418    /// been released.
419    ///
420    /// The update function receives a context appropriate for its environment.
421    /// When updating in an `AppContext`, it receives a `ModelContext`.
422    /// When updating an a `WindowContext`, it receives a `ViewContext`.
423    pub fn update<C: Context, R>(
424        &self,
425        cx: &mut C,
426        update: impl FnOnce(&mut T, &mut C::EntityContext<'_, '_, T>) -> R,
427    ) -> Result<R>
428    where
429        Result<C::Result<R>>: crate::Flatten<R>,
430    {
431        crate::Flatten::flatten(
432            self.upgrade()
433                .ok_or_else(|| anyhow!("entity release"))
434                .map(|this| cx.update_entity(&this, update)),
435        )
436    }
437}
438
439impl<T: Send + Sync + 'static> Hash for WeakHandle<T> {
440    fn hash<H: Hasher>(&self, state: &mut H) {
441        self.any_handle.hash(state);
442    }
443}
444
445impl<T: Send + Sync + 'static> PartialEq for WeakHandle<T> {
446    fn eq(&self, other: &Self) -> bool {
447        self.any_handle == other.any_handle
448    }
449}
450
451impl<T: Send + Sync + 'static> Eq for WeakHandle<T> {}