Avoid reference cycle between `Client` and its models

Antonio Scandurra created

Change summary

crates/client/src/client.rs | 23 +++++++++++++++--------
crates/gpui/src/app.rs      |  9 +++++++++
2 files changed, 24 insertions(+), 8 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -133,9 +133,8 @@ struct ClientState {
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     _maintain_connection: Option<Task<()>>,
     heartbeat_interval: Duration,
-
     models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
-    models_by_message_type: HashMap<TypeId, AnyModelHandle>,
+    models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
     model_types_by_message_type: HashMap<TypeId, TypeId>,
     message_handlers: HashMap<
         TypeId,
@@ -348,18 +347,22 @@ impl Client {
     {
         let message_type_id = TypeId::of::<M>();
 
-        let client = self.clone();
+        let client = Arc::downgrade(self);
         let mut state = self.state.write();
         state
             .models_by_message_type
-            .insert(message_type_id, model.into());
+            .insert(message_type_id, model.downgrade().into());
 
         let prev_handler = state.message_handlers.insert(
             message_type_id,
             Arc::new(move |handle, envelope, cx| {
                 let model = handle.downcast::<E>().unwrap();
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                handler(model, *envelope, client.clone(), cx).boxed_local()
+                if let Some(client) = client.upgrade() {
+                    handler(model, *envelope, client.clone(), cx).boxed_local()
+                } else {
+                    async move { Ok(()) }.boxed_local()
+                }
             }),
         );
         if prev_handler.is_some() {
@@ -385,7 +388,7 @@ impl Client {
         let model_type_id = TypeId::of::<E>();
         let message_type_id = TypeId::of::<M>();
 
-        let client = self.clone();
+        let client = Arc::downgrade(self);
         let mut state = self.state.write();
         state
             .model_types_by_message_type
@@ -408,7 +411,11 @@ impl Client {
             Arc::new(move |handle, envelope, cx| {
                 let model = handle.downcast::<E>().unwrap();
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                handler(model, *envelope, client.clone(), cx).boxed_local()
+                if let Some(client) = client.upgrade() {
+                    handler(model, *envelope, client.clone(), cx).boxed_local()
+                } else {
+                    async move { Ok(()) }.boxed_local()
+                }
             }),
         );
         if prev_handler.is_some() {
@@ -550,7 +557,7 @@ impl Client {
                         let model = state
                             .models_by_message_type
                             .get(&payload_type_id)
-                            .cloned()
+                            .and_then(|model| model.upgrade(&cx))
                             .or_else(|| {
                                 let model_type_id =
                                     *state.model_types_by_message_type.get(&payload_type_id)?;

crates/gpui/src/app.rs 🔗

@@ -3567,6 +3567,15 @@ impl AnyWeakModelHandle {
     }
 }
 
+impl<T: Entity> From<WeakModelHandle<T>> for AnyWeakModelHandle {
+    fn from(handle: WeakModelHandle<T>) -> Self {
+        AnyWeakModelHandle {
+            model_id: handle.model_id,
+            model_type: TypeId::of::<T>(),
+        }
+    }
+}
+
 pub struct WeakViewHandle<T> {
     window_id: usize,
     view_id: usize,