Merge pull request #325 from zed-industries/fix-more-subscription-panics

Antonio Scandurra created

Don't register an entity ID extractor for non-entity subscriptions

Change summary

crates/client/src/client.rs | 110 ++++++++++++++++++++++++++------------
1 file changed, 76 insertions(+), 34 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -124,7 +124,7 @@ struct ClientState {
     status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
-        (TypeId, u64),
+        (TypeId, Option<u64>),
         Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
     >,
     _maintain_connection: Option<Task<()>>,
@@ -152,14 +152,13 @@ impl Default for ClientState {
 
 pub struct Subscription {
     client: Weak<Client>,
-    id: (TypeId, u64),
+    id: (TypeId, Option<u64>),
 }
 
 impl Drop for Subscription {
     fn drop(&mut self) {
         if let Some(client) = self.client.upgrade() {
             let mut state = client.state.write();
-            let _ = state.entity_id_extractors.remove(&self.id.0).unwrap();
             let _ = state.model_handlers.remove(&self.id).unwrap();
         }
     }
@@ -267,18 +266,11 @@ impl Client {
             + Sync
             + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
     {
-        let subscription_id = (TypeId::of::<T>(), Default::default());
+        let subscription_id = (TypeId::of::<T>(), None);
         let client = self.clone();
         let mut state = self.state.write();
         let model = cx.weak_handle();
-        let prev_extractor = state
-            .entity_id_extractors
-            .insert(subscription_id.0, Box::new(|_| Default::default()));
-        if prev_extractor.is_some() {
-            panic!("registered a handler for the same entity twice")
-        }
-
-        state.model_handlers.insert(
+        let prev_handler = state.model_handlers.insert(
             subscription_id,
             Some(Box::new(move |envelope, cx| {
                 if let Some(model) = model.upgrade(cx) {
@@ -291,6 +283,9 @@ impl Client {
                 }
             })),
         );
+        if prev_handler.is_some() {
+            panic!("registered handler for the same message twice");
+        }
 
         Subscription {
             client: Arc::downgrade(self),
@@ -312,7 +307,7 @@ impl Client {
             + Sync
             + FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
     {
-        let subscription_id = (TypeId::of::<T>(), remote_id);
+        let subscription_id = (TypeId::of::<T>(), Some(remote_id));
         let client = self.clone();
         let mut state = self.state.write();
         let model = cx.weak_handle();
@@ -439,29 +434,27 @@ impl Client {
                 async move {
                     while let Some(message) = incoming.recv().await {
                         let mut state = this.state.write();
-                        if let Some(extract_entity_id) =
+                        let payload_type_id = message.payload_type_id();
+                        let entity_id = if let Some(extract_entity_id) =
                             state.entity_id_extractors.get(&message.payload_type_id())
                         {
-                            let payload_type_id = message.payload_type_id();
-                            let entity_id = (extract_entity_id)(message.as_ref());
-                            let handler_key = (payload_type_id, entity_id);
-                            if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
-                                let mut handler = handler.take().unwrap();
-                                drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
-                                let start_time = Instant::now();
-                                log::info!("RPC client message {}", message.payload_type_name());
-                                (handler)(message, &mut cx);
-                                log::info!(
-                                    "RPC message handled. duration:{:?}",
-                                    start_time.elapsed()
-                                );
-
-                                let mut state = this.state.write();
-                                if state.model_handlers.contains_key(&handler_key) {
-                                    state.model_handlers.insert(handler_key, Some(handler));
-                                }
-                            } else {
-                                log::info!("unhandled message {}", message.payload_type_name());
+                            Some((extract_entity_id)(message.as_ref()))
+                        } else {
+                            None
+                        };
+
+                        let handler_key = (payload_type_id, entity_id);
+                        if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
+                            let mut handler = handler.take().unwrap();
+                            drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
+                            let start_time = Instant::now();
+                            log::info!("RPC client message {}", message.payload_type_name());
+                            (handler)(message, &mut cx);
+                            log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+
+                            let mut state = this.state.write();
+                            if state.model_handlers.contains_key(&handler_key) {
+                                state.model_handlers.insert(handler_key, Some(handler));
                             }
                         } else {
                             log::info!("unhandled message {}", message.payload_type_name());
@@ -811,6 +804,55 @@ mod tests {
         assert_eq!(decode_worktree_url("not://the-right-format"), None);
     }
 
+    #[gpui::test]
+    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let model = cx.add_model(|_| Model { subscription: None });
+        let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
+        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
+        let _subscription1 = model.update(&mut cx, |_, cx| {
+            client.subscribe_to_entity(
+                1,
+                cx,
+                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
+                    postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
+                    Ok(())
+                },
+            )
+        });
+        let _subscription2 = model.update(&mut cx, |_, cx| {
+            client.subscribe_to_entity(
+                2,
+                cx,
+                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
+                    postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
+                    Ok(())
+                },
+            )
+        });
+
+        // Ensure dropping a subscription for the same entity type still allows receiving of
+        // messages for other entity IDs of the same type.
+        let subscription3 = model.update(&mut cx, |_, cx| {
+            client.subscribe_to_entity(
+                3,
+                cx,
+                move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| Ok(()),
+            )
+        });
+        drop(subscription3);
+
+        server.send(proto::UnshareProject { project_id: 1 }).await;
+        server.send(proto::UnshareProject { project_id: 2 }).await;
+        done_rx1.recv().await.unwrap();
+        done_rx2.recv().await.unwrap();
+    }
+
     #[gpui::test]
     async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
         cx.foreground().forbid_parking();