@@ -124,7 +124,7 @@ struct ClientState {
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
- (TypeId, u64),
+ (TypeId, Option<u64>),
Option<Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>>,
>,
_maintain_connection: Option<Task<()>>,
@@ -152,14 +152,13 @@ impl Default for ClientState {
pub struct Subscription {
client: Weak<Client>,
- id: (TypeId, u64),
+ id: (TypeId, Option<u64>),
}
impl Drop for Subscription {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let mut state = client.state.write();
- let _ = state.entity_id_extractors.remove(&self.id.0).unwrap();
let _ = state.model_handlers.remove(&self.id).unwrap();
}
}
@@ -267,18 +266,11 @@ impl Client {
+ Sync
+ FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
{
- let subscription_id = (TypeId::of::<T>(), Default::default());
+ let subscription_id = (TypeId::of::<T>(), None);
let client = self.clone();
let mut state = self.state.write();
let model = cx.weak_handle();
- let prev_extractor = state
- .entity_id_extractors
- .insert(subscription_id.0, Box::new(|_| Default::default()));
- if prev_extractor.is_some() {
- panic!("registered a handler for the same entity twice")
- }
-
- state.model_handlers.insert(
+ let prev_handler = state.model_handlers.insert(
subscription_id,
Some(Box::new(move |envelope, cx| {
if let Some(model) = model.upgrade(cx) {
@@ -291,6 +283,9 @@ impl Client {
}
})),
);
+ if prev_handler.is_some() {
+ panic!("registered handler for the same message twice");
+ }
Subscription {
client: Arc::downgrade(self),
@@ -312,7 +307,7 @@ impl Client {
+ Sync
+ FnMut(&mut M, TypedEnvelope<T>, Arc<Self>, &mut ModelContext<M>) -> Result<()>,
{
- let subscription_id = (TypeId::of::<T>(), remote_id);
+ let subscription_id = (TypeId::of::<T>(), Some(remote_id));
let client = self.clone();
let mut state = self.state.write();
let model = cx.weak_handle();
@@ -439,29 +434,27 @@ impl Client {
async move {
while let Some(message) = incoming.recv().await {
let mut state = this.state.write();
- if let Some(extract_entity_id) =
+ let payload_type_id = message.payload_type_id();
+ let entity_id = if let Some(extract_entity_id) =
state.entity_id_extractors.get(&message.payload_type_id())
{
- let payload_type_id = message.payload_type_id();
- let entity_id = (extract_entity_id)(message.as_ref());
- let handler_key = (payload_type_id, entity_id);
- if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
- let mut handler = handler.take().unwrap();
- drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
- let start_time = Instant::now();
- log::info!("RPC client message {}", message.payload_type_name());
- (handler)(message, &mut cx);
- log::info!(
- "RPC message handled. duration:{:?}",
- start_time.elapsed()
- );
-
- let mut state = this.state.write();
- if state.model_handlers.contains_key(&handler_key) {
- state.model_handlers.insert(handler_key, Some(handler));
- }
- } else {
- log::info!("unhandled message {}", message.payload_type_name());
+ Some((extract_entity_id)(message.as_ref()))
+ } else {
+ None
+ };
+
+ let handler_key = (payload_type_id, entity_id);
+ if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
+ let mut handler = handler.take().unwrap();
+ drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
+ let start_time = Instant::now();
+ log::info!("RPC client message {}", message.payload_type_name());
+ (handler)(message, &mut cx);
+ log::info!("RPC message handled. duration:{:?}", start_time.elapsed());
+
+ let mut state = this.state.write();
+ if state.model_handlers.contains_key(&handler_key) {
+ state.model_handlers.insert(handler_key, Some(handler));
}
} else {
log::info!("unhandled message {}", message.payload_type_name());
@@ -811,6 +804,55 @@ mod tests {
assert_eq!(decode_worktree_url("not://the-right-format"), None);
}
+ #[gpui::test]
+ async fn test_subscribing_to_entity(mut cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new(FakeHttpClient::with_404_response());
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+ let model = cx.add_model(|_| Model { subscription: None });
+ let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
+ let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
+ let _subscription1 = model.update(&mut cx, |_, cx| {
+ client.subscribe_to_entity(
+ 1,
+ cx,
+ move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
+ postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
+ Ok(())
+ },
+ )
+ });
+ let _subscription2 = model.update(&mut cx, |_, cx| {
+ client.subscribe_to_entity(
+ 2,
+ cx,
+ move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
+ postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
+ Ok(())
+ },
+ )
+ });
+
+ // Ensure dropping a subscription for the same entity type still allows receiving of
+ // messages for other entity IDs of the same type.
+ let subscription3 = model.update(&mut cx, |_, cx| {
+ client.subscribe_to_entity(
+ 3,
+ cx,
+ move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| Ok(()),
+ )
+ });
+ drop(subscription3);
+
+ server.send(proto::UnshareProject { project_id: 1 }).await;
+ server.send(proto::UnshareProject { project_id: 2 }).await;
+ done_rx1.recv().await.unwrap();
+ done_rx2.recv().await.unwrap();
+ }
+
#[gpui::test]
async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
cx.foreground().forbid_parking();