@@ -125,7 +125,7 @@ struct ClientState {
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
(TypeId, u64),
- Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
+ Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
>,
_maintain_connection: Option<Task<()>>,
heartbeat_interval: Duration,
@@ -158,14 +158,9 @@ pub struct Subscription {
impl Drop for Subscription {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
- drop(
- client
- .state
- .write()
- .model_handlers
- .remove(&self.id)
- .unwrap(),
- );
+ let mut state = client.state.write();
+ let _ = state.entity_id_extractors.remove(&self.id.0).unwrap();
+ let _ = state.model_handlers.remove(&self.id).unwrap();
}
}
}
@@ -285,7 +280,7 @@ impl Client {
state.model_handlers.insert(
subscription_id,
- Box::new(move |envelope, cx| {
+ Some(Box::new(move |envelope, cx| {
if let Some(model) = model.upgrade(cx) {
let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
model.update(cx, |model, cx| {
@@ -294,7 +289,7 @@ impl Client {
}
});
}
- }),
+ })),
);
Subscription {
@@ -335,7 +330,7 @@ impl Client {
});
let prev_handler = state.model_handlers.insert(
subscription_id,
- Box::new(move |envelope, cx| {
+ Some(Box::new(move |envelope, cx| {
if let Some(model) = model.upgrade(cx) {
let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
model.update(cx, |model, cx| {
@@ -344,7 +339,7 @@ impl Client {
}
});
}
- }),
+ })),
);
if prev_handler.is_some() {
panic!("registered a handler for the same entity twice")
@@ -450,7 +445,8 @@ impl Client {
let payload_type_id = message.payload_type_id();
let entity_id = (extract_entity_id)(message.as_ref());
let handler_key = (payload_type_id, entity_id);
- if let Some(mut handler) = state.model_handlers.remove(&handler_key) {
+ if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
+ let mut handler = handler.take().unwrap();
drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
let start_time = Instant::now();
log::info!("RPC client message {}", message.payload_type_name());
@@ -459,10 +455,11 @@ impl Client {
"RPC message handled. duration:{:?}",
start_time.elapsed()
);
- this.state
- .write()
- .model_handlers
- .insert(handler_key, handler);
+
+ let mut state = this.state.write();
+ if state.model_handlers.contains_key(&handler_key) {
+ state.model_handlers.insert(handler_key, Some(handler));
+ }
} else {
log::info!("unhandled message {}", message.payload_type_name());
}
@@ -813,4 +810,64 @@ mod tests {
);
assert_eq!(decode_worktree_url("not://the-right-format"), None);
}
+
+ #[gpui::test]
+ async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new(FakeHttpClient::with_404_response());
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+ let model = cx.add_model(|_| Model { subscription: None });
+ let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
+ let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
+ let subscription1 = model.update(&mut cx, |_, cx| {
+ client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
+ Ok(())
+ })
+ });
+ drop(subscription1);
+ let _subscription2 = model.update(&mut cx, |_, cx| {
+ client.subscribe(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
+ Ok(())
+ })
+ });
+ server.send(proto::Ping {}).await;
+ done_rx2.recv().await.unwrap();
+ }
+
+ #[gpui::test]
+ async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new(FakeHttpClient::with_404_response());
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+ let model = cx.add_model(|_| Model { subscription: None });
+ let (mut done_tx, mut done_rx) = postage::oneshot::channel();
+ model.update(&mut cx, |model, cx| {
+ model.subscription = Some(client.subscribe(
+ cx,
+ move |model, _: TypedEnvelope<proto::Ping>, _, _| {
+ model.subscription.take();
+ postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
+ Ok(())
+ },
+ ));
+ });
+ server.send(proto::Ping {}).await;
+ done_rx.recv().await.unwrap();
+ }
+
+ struct Model {
+ subscription: Option<Subscription>,
+ }
+
+ impl Entity for Model {
+ type Event = ();
+ }
}