Merge pull request #316 from zed-industries/fix-subscription-panic

Max Brunsfeld created

Fix `rpc::Client` subscription panics

Change summary

crates/client/Cargo.toml    |  6 ++
crates/client/src/client.rs | 93 +++++++++++++++++++++++++++++++-------
2 files changed, 80 insertions(+), 19 deletions(-)

Detailed changes

crates/client/Cargo.toml 🔗

@@ -7,7 +7,7 @@ edition = "2018"
 path = "src/client.rs"
 
 [features]
-test-support = ["rpc/test-support"]
+test-support = ["gpui/test-support", "rpc/test-support"]
 
 [dependencies]
 gpui = { path = "../gpui" }
@@ -29,3 +29,7 @@ surf = "2.2"
 thiserror = "1.0.29"
 time = "0.3"
 tiny_http = "0.8"
+
+[dev-dependencies]
+gpui = { path = "../gpui", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }

crates/client/src/client.rs 🔗

@@ -125,7 +125,7 @@ struct ClientState {
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
         (TypeId, u64),
-        Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
+        Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
     >,
     _maintain_connection: Option<Task<()>>,
     heartbeat_interval: Duration,
@@ -158,14 +158,9 @@ pub struct Subscription {
 impl Drop for Subscription {
     fn drop(&mut self) {
         if let Some(client) = self.client.upgrade() {
-            drop(
-                client
-                    .state
-                    .write()
-                    .model_handlers
-                    .remove(&self.id)
-                    .unwrap(),
-            );
+            let mut state = client.state.write();
+            let _ = state.entity_id_extractors.remove(&self.id.0).unwrap();
+            let _ = state.model_handlers.remove(&self.id).unwrap();
         }
     }
 }
@@ -285,7 +280,7 @@ impl Client {
 
         state.model_handlers.insert(
             subscription_id,
-            Box::new(move |envelope, cx| {
+            Some(Box::new(move |envelope, cx| {
                 if let Some(model) = model.upgrade(cx) {
                     let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
                     model.update(cx, |model, cx| {
@@ -294,7 +289,7 @@ impl Client {
                         }
                     });
                 }
-            }),
+            })),
         );
 
         Subscription {
@@ -335,7 +330,7 @@ impl Client {
             });
         let prev_handler = state.model_handlers.insert(
             subscription_id,
-            Box::new(move |envelope, cx| {
+            Some(Box::new(move |envelope, cx| {
                 if let Some(model) = model.upgrade(cx) {
                     let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
                     model.update(cx, |model, cx| {
@@ -344,7 +339,7 @@ impl Client {
                         }
                     });
                 }
-            }),
+            })),
         );
         if prev_handler.is_some() {
             panic!("registered a handler for the same entity twice")
@@ -450,7 +445,8 @@ impl Client {
                             let payload_type_id = message.payload_type_id();
                             let entity_id = (extract_entity_id)(message.as_ref());
                             let handler_key = (payload_type_id, entity_id);
-                            if let Some(mut handler) = state.model_handlers.remove(&handler_key) {
+                            if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
+                                let mut handler = handler.take().unwrap();
                                 drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                                 let start_time = Instant::now();
                                 log::info!("RPC client message {}", message.payload_type_name());
@@ -459,10 +455,11 @@ impl Client {
                                     "RPC message handled. duration:{:?}",
                                     start_time.elapsed()
                                 );
-                                this.state
-                                    .write()
-                                    .model_handlers
-                                    .insert(handler_key, handler);
+
+                                let mut state = this.state.write();
+                                if state.model_handlers.contains_key(&handler_key) {
+                                    state.model_handlers.insert(handler_key, Some(handler));
+                                }
                             } else {
                                 log::info!("unhandled message {}", message.payload_type_name());
                             }
@@ -813,4 +810,64 @@ mod tests {
         );
         assert_eq!(decode_worktree_url("not://the-right-format"), None);
     }
+
+    #[gpui::test]
+    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let model = cx.add_model(|_| Model { subscription: None });
+        let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
+        let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
+        let subscription1 = model.update(&mut cx, |_, cx| {
+            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+                postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
+                Ok(())
+            })
+        });
+        drop(subscription1);
+        let _subscription2 = model.update(&mut cx, |_, cx| {
+            client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+                postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
+                Ok(())
+            })
+        });
+        server.send(proto::Ping {}).await;
+        done_rx2.recv().await.unwrap();
+    }
+
+    #[gpui::test]
+    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let model = cx.add_model(|_| Model { subscription: None });
+        let (mut done_tx, mut done_rx) = postage::oneshot::channel();
+        model.update(&mut cx, |model, cx| {
+            model.subscription = Some(client.subscribe(
+                cx,
+                move |model, _: TypedEnvelope<proto::Ping>, _, _| {
+                    model.subscription.take();
+                    postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
+                    Ok(())
+                },
+            ));
+        });
+        server.send(proto::Ping {}).await;
+        done_rx.recv().await.unwrap();
+    }
+
+    struct Model {
+        subscription: Option<Subscription>,
+    }
+
+    impl Entity for Model {
+        type Event = ();
+    }
 }