:art: client

Max Brunsfeld created

Forgot to push this yesterday night.

Change summary

crates/client/src/client.rs | 198 +++++++++++++++-----------------------
1 file changed, 80 insertions(+), 118 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -136,7 +136,7 @@ impl Status {
 struct ClientState {
     credentials: Option<Credentials>,
     status: (watch::Sender<Status>, watch::Receiver<Status>),
-    entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
+    entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
     _reconnect_task: Option<Task<()>>,
     reconnect_interval: Duration,
     entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
@@ -150,6 +150,7 @@ struct ClientState {
                 + Fn(
                     AnyEntityHandle,
                     Box<dyn AnyTypedEnvelope>,
+                    &Arc<Client>,
                     AsyncAppContext,
                 ) -> LocalBoxFuture<'static, Result<()>>,
         >,
@@ -328,12 +329,11 @@ impl Client {
         remote_id: u64,
         cx: &mut ViewContext<T>,
     ) -> Subscription {
-        let handle = AnyViewHandle::from(cx.handle());
-        let mut state = self.state.write();
         let id = (TypeId::of::<T>(), remote_id);
-        state
+        self.state
+            .write()
             .entities_by_type_and_remote_id
-            .insert(id, AnyWeakEntityHandle::View(handle.downgrade()));
+            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
         Subscription::Entity {
             client: Arc::downgrade(self),
             id,
@@ -345,12 +345,11 @@ impl Client {
         remote_id: u64,
         cx: &mut ModelContext<T>,
     ) -> Subscription {
-        let handle = AnyModelHandle::from(cx.handle());
-        let mut state = self.state.write();
         let id = (TypeId::of::<T>(), remote_id);
-        state
+        self.state
+            .write()
             .entities_by_type_and_remote_id
-            .insert(id, AnyWeakEntityHandle::Model(handle.downgrade()));
+            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
         Subscription::Entity {
             client: Arc::downgrade(self),
             id,
@@ -373,7 +372,6 @@ impl Client {
     {
         let message_type_id = TypeId::of::<M>();
 
-        let client = Arc::downgrade(self);
         let mut state = self.state.write();
         state
             .models_by_message_type
@@ -381,7 +379,7 @@ impl Client {
 
         let prev_handler = state.message_handlers.insert(
             message_type_id,
-            Arc::new(move |handle, envelope, cx| {
+            Arc::new(move |handle, envelope, client, cx| {
                 let handle = if let AnyEntityHandle::Model(handle) = handle {
                     handle
                 } else {
@@ -389,11 +387,7 @@ impl Client {
                 };
                 let model = handle.downcast::<E>().unwrap();
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                if let Some(client) = client.upgrade() {
-                    handler(model, *envelope, client.clone(), cx).boxed_local()
-                } else {
-                    async move { Ok(()) }.boxed_local()
-                }
+                handler(model, *envelope, client.clone(), cx).boxed_local()
             }),
         );
         if prev_handler.is_some() {
@@ -416,47 +410,13 @@ impl Client {
             + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
         F: 'static + Future<Output = Result<()>>,
     {
-        let entity_type_id = TypeId::of::<E>();
-        let message_type_id = TypeId::of::<M>();
-
-        let client = Arc::downgrade(self);
-        let mut state = self.state.write();
-        state
-            .entity_types_by_message_type
-            .insert(message_type_id, entity_type_id);
-        state
-            .entity_id_extractors
-            .entry(message_type_id)
-            .or_insert_with(|| {
-                Box::new(|envelope| {
-                    let envelope = envelope
-                        .as_any()
-                        .downcast_ref::<TypedEnvelope<M>>()
-                        .unwrap();
-                    envelope.payload.remote_entity_id()
-                })
-            });
-
-        let prev_handler = state.message_handlers.insert(
-            message_type_id,
-            Arc::new(move |handle, envelope, cx| {
-                let handle = if let AnyEntityHandle::View(handle) = handle {
-                    handle
-                } else {
-                    unreachable!();
-                };
-                let model = handle.downcast::<E>().unwrap();
-                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                if let Some(client) = client.upgrade() {
-                    handler(model, *envelope, client.clone(), cx).boxed_local()
-                } else {
-                    async move { Ok(()) }.boxed_local()
-                }
-            }),
-        );
-        if prev_handler.is_some() {
-            panic!("registered handler for the same message twice");
-        }
+        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
+            if let AnyEntityHandle::View(handle) = handle {
+                handler(handle.downcast::<E>().unwrap(), message, client, cx)
+            } else {
+                unreachable!();
+            }
+        })
     }
 
     pub fn add_model_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
@@ -468,11 +428,29 @@ impl Client {
             + Sync
             + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
         F: 'static + Future<Output = Result<()>>,
+    {
+        self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
+            if let AnyEntityHandle::Model(handle) = handle {
+                handler(handle.downcast::<E>().unwrap(), message, client, cx)
+            } else {
+                unreachable!();
+            }
+        })
+    }
+
+    fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
+    where
+        M: EntityMessage,
+        E: Entity,
+        H: 'static
+            + Send
+            + Sync
+            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+        F: 'static + Future<Output = Result<()>>,
     {
         let model_type_id = TypeId::of::<E>();
         let message_type_id = TypeId::of::<M>();
 
-        let client = Arc::downgrade(self);
         let mut state = self.state.write();
         state
             .entity_types_by_message_type
@@ -481,30 +459,20 @@ impl Client {
             .entity_id_extractors
             .entry(message_type_id)
             .or_insert_with(|| {
-                Box::new(|envelope| {
-                    let envelope = envelope
+                |envelope| {
+                    envelope
                         .as_any()
                         .downcast_ref::<TypedEnvelope<M>>()
-                        .unwrap();
-                    envelope.payload.remote_entity_id()
-                })
+                        .unwrap()
+                        .payload
+                        .remote_entity_id()
+                }
             });
-
         let prev_handler = state.message_handlers.insert(
             message_type_id,
-            Arc::new(move |handle, envelope, cx| {
-                if let Some(client) = client.upgrade() {
-                    let handle = if let AnyEntityHandle::Model(handle) = handle {
-                        handle
-                    } else {
-                        unreachable!();
-                    };
-                    let model = handle.downcast::<E>().unwrap();
-                    let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
-                    handler(model, *envelope, client.clone(), cx).boxed_local()
-                } else {
-                    async move { Ok(()) }.boxed_local()
-                }
+            Arc::new(move |handle, envelope, client, cx| {
+                let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
+                handler(handle, *envelope, client.clone(), cx).boxed_local()
             }),
         );
         if prev_handler.is_some() {
@@ -522,26 +490,12 @@ impl Client {
             + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
         F: 'static + Future<Output = Result<M::Response>>,
     {
-        self.add_model_message_handler(move |model, envelope, client, cx| {
-            let receipt = envelope.receipt();
-            let response = handler(model, envelope, client.clone(), cx);
-            async move {
-                match response.await {
-                    Ok(response) => {
-                        client.respond(receipt, response)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        client.respond_with_error(
-                            receipt,
-                            proto::Error {
-                                message: error.to_string(),
-                            },
-                        )?;
-                        Err(error)
-                    }
-                }
-            }
+        self.add_model_message_handler(move |entity, envelope, client, cx| {
+            Self::respond_to_request::<M, _>(
+                envelope.receipt(),
+                handler(entity, envelope, client.clone(), cx),
+                client,
+            )
         })
     }
 
@@ -555,29 +509,37 @@ impl Client {
             + Fn(ViewHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
         F: 'static + Future<Output = Result<M::Response>>,
     {
-        self.add_view_message_handler(move |view, envelope, client, cx| {
-            let receipt = envelope.receipt();
-            let response = handler(view, envelope, client.clone(), cx);
-            async move {
-                match response.await {
-                    Ok(response) => {
-                        client.respond(receipt, response)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        client.respond_with_error(
-                            receipt,
-                            proto::Error {
-                                message: error.to_string(),
-                            },
-                        )?;
-                        Err(error)
-                    }
-                }
-            }
+        self.add_view_message_handler(move |entity, envelope, client, cx| {
+            Self::respond_to_request::<M, _>(
+                envelope.receipt(),
+                handler(entity, envelope, client.clone(), cx),
+                client,
+            )
         })
     }
 
+    async fn respond_to_request<T: RequestMessage, F: Future<Output = Result<T::Response>>>(
+        receipt: Receipt<T>,
+        response: F,
+        client: Arc<Self>,
+    ) -> Result<()> {
+        match response.await {
+            Ok(response) => {
+                client.respond(receipt, response)?;
+                Ok(())
+            }
+            Err(error) => {
+                client.respond_with_error(
+                    receipt,
+                    proto::Error {
+                        message: error.to_string(),
+                    },
+                )?;
+                Err(error)
+            }
+        }
+    }
+
     pub fn has_keychain_credentials(&self, cx: &AsyncAppContext) -> bool {
         read_credentials_from_keychain(cx).is_some()
     }
@@ -718,7 +680,7 @@ impl Client {
                         if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
                         {
                             drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
-                            let future = handler(model, message, cx.clone());
+                            let future = handler(model, message, &this, cx.clone());
 
                             let client_id = this.id;
                             log::debug!(