Buffer messages in Client while no entity is listening to them

Antonio Scandurra created

Change summary

crates/client/src/channel.rs  |   2 
crates/client/src/client.rs   | 360 ++++++++++++++++++++++--------------
crates/project/src/project.rs |   8 
3 files changed, 227 insertions(+), 143 deletions(-)

Detailed changes

crates/client/src/channel.rs 🔗

@@ -190,7 +190,7 @@ impl Channel {
         rpc: Arc<Client>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let _subscription = rpc.add_model_for_remote_entity(cx.handle(), details.id);
+        let _subscription = rpc.add_model_for_remote_entity(details.id, cx);
 
         {
             let user_store = user_store.clone();

crates/client/src/client.rs 🔗

@@ -13,7 +13,7 @@ use async_tungstenite::tungstenite::{
 };
 use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
 use gpui::{
-    action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelHandle,
+    action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelContext, ModelHandle,
     MutableAppContext, Task,
 };
 use http::HttpClient;
@@ -140,7 +140,7 @@ struct ClientState {
     model_types_by_message_type: HashMap<TypeId, TypeId>,
     message_handlers: HashMap<
         TypeId,
-        Box<
+        Arc<
             dyn Send
                 + Sync
                 + Fn(
@@ -175,16 +175,33 @@ impl Default for ClientState {
     }
 }
 
-pub struct Subscription {
-    client: Weak<Client>,
-    id: (TypeId, u64),
+pub enum Subscription {
+    Entity {
+        client: Weak<Client>,
+        id: (TypeId, u64),
+    },
+    Message {
+        client: Weak<Client>,
+        id: TypeId,
+    },
 }
 
 impl Drop for Subscription {
     fn drop(&mut self) {
-        if let Some(client) = self.client.upgrade() {
-            let mut state = client.state.write();
-            let _ = state.models_by_entity_type_and_remote_id.remove(&self.id);
+        match self {
+            Subscription::Entity { client, id } => {
+                if let Some(client) = client.upgrade() {
+                    let mut state = client.state.write();
+                    let _ = state.models_by_entity_type_and_remote_id.remove(id);
+                }
+            }
+            Subscription::Message { client, id } => {
+                if let Some(client) = client.upgrade() {
+                    let mut state = client.state.write();
+                    let _ = state.model_types_by_message_type.remove(id);
+                    let _ = state.message_handlers.remove(id);
+                }
+            }
         }
     }
 }
@@ -285,21 +302,66 @@ impl Client {
 
     pub fn add_model_for_remote_entity<T: Entity>(
         self: &Arc<Self>,
-        handle: ModelHandle<T>,
         remote_id: u64,
+        cx: &mut ModelContext<T>,
     ) -> Subscription {
+        let handle = AnyModelHandle::from(cx.handle());
         let mut state = self.state.write();
         let id = (TypeId::of::<T>(), remote_id);
         state
             .models_by_entity_type_and_remote_id
-            .insert(id, AnyModelHandle::from(handle).downgrade());
-        Subscription {
+            .insert(id, handle.downgrade());
+        let pending_messages = state.pending_messages.remove(&id);
+        drop(state);
+
+        let client_id = self.id;
+        for message in pending_messages.into_iter().flatten() {
+            let type_id = message.payload_type_id();
+            let type_name = message.payload_type_name();
+            let state = self.state.read();
+            if let Some(handler) = state.message_handlers.get(&type_id).cloned() {
+                let future = (handler)(handle.clone(), message, cx.to_async());
+                drop(state);
+                log::debug!(
+                    "deferred rpc message received. client_id:{}, name:{}",
+                    client_id,
+                    type_name
+                );
+                cx.foreground()
+                    .spawn(async move {
+                        match future.await {
+                            Ok(()) => {
+                                log::debug!(
+                                    "deferred rpc message handled. client_id:{}, name:{}",
+                                    client_id,
+                                    type_name
+                                );
+                            }
+                            Err(error) => {
+                                log::error!(
+                                    "error handling deferred message. client_id:{}, name:{}, {}",
+                                    client_id,
+                                    type_name,
+                                    error
+                                );
+                            }
+                        }
+                    })
+                    .detach();
+            }
+        }
+
+        Subscription::Entity {
             client: Arc::downgrade(self),
             id,
         }
     }
 
-    pub fn add_message_handler<M, E, H, F>(self: &Arc<Self>, model: ModelHandle<E>, handler: H)
+    pub fn add_message_handler<M, E, H, F>(
+        self: &Arc<Self>,
+        model: ModelHandle<E>,
+        handler: H,
+    ) -> Subscription
     where
         M: EnvelopedMessage,
         E: Entity,
@@ -319,7 +381,7 @@ impl Client {
 
         let prev_handler = state.message_handlers.insert(
             message_type_id,
-            Box::new(move |handle, envelope, cx| {
+            Arc::new(move |handle, envelope, cx| {
                 let model = handle.downcast::<E>().unwrap();
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                 handler(model, *envelope, client.clone(), cx).boxed_local()
@@ -328,6 +390,11 @@ impl Client {
         if prev_handler.is_some() {
             panic!("registered handler for the same message twice");
         }
+
+        Subscription::Message {
+            client: Arc::downgrade(self),
+            id: message_type_id,
+        }
     }
 
     pub fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
@@ -363,7 +430,7 @@ impl Client {
 
         let prev_handler = state.message_handlers.insert(
             message_type_id,
-            Box::new(move |handle, envelope, cx| {
+            Arc::new(move |handle, envelope, cx| {
                 let model = handle.downcast::<E>().unwrap();
                 let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
                 handler(model, *envelope, client.clone(), cx).boxed_local()
@@ -501,37 +568,54 @@ impl Client {
                         let mut state = this.state.write();
                         let payload_type_id = message.payload_type_id();
                         let type_name = message.payload_type_name();
-
-                        let model = state.models_by_message_type.get(&payload_type_id).cloned().or_else(|| {
-                            let extract_entity_id = state.entity_id_extractors.get(&message.payload_type_id())?;
-                            let entity_id = (extract_entity_id)(message.as_ref());
-                            let model_type_id = *state.model_types_by_message_type.get(&payload_type_id)?;
-
-                            // TODO - if we don't have this model yet, then buffer the message
-                            let model = state.models_by_entity_type_and_remote_id.get(&(model_type_id, entity_id))?;
-
-                            if let Some(model) = model.upgrade(&cx) {
-                                Some(model)
-                            } else {
-                                state.models_by_entity_type_and_remote_id.remove(&(model_type_id, entity_id));
-                                None
-                            }
-                        });
+                        let model_type_id = state
+                            .model_types_by_message_type
+                            .get(&payload_type_id)
+                            .copied();
+                        let entity_id = state
+                            .entity_id_extractors
+                            .get(&message.payload_type_id())
+                            .map(|extract_entity_id| (extract_entity_id)(message.as_ref()));
+
+                        let model = state
+                            .models_by_message_type
+                            .get(&payload_type_id)
+                            .cloned()
+                            .or_else(|| {
+                                let model_type_id = model_type_id?;
+                                let entity_id = entity_id?;
+                                let model = state
+                                    .models_by_entity_type_and_remote_id
+                                    .get(&(model_type_id, entity_id))?;
+                                if let Some(model) = model.upgrade(&cx) {
+                                    Some(model)
+                                } else {
+                                    state
+                                        .models_by_entity_type_and_remote_id
+                                        .remove(&(model_type_id, entity_id));
+                                    None
+                                }
+                            });
 
                         let model = if let Some(model) = model {
                             model
                         } else {
                             log::info!("unhandled message {}", type_name);
+                            if let Some((model_type_id, entity_id)) = model_type_id.zip(entity_id) {
+                                state
+                                    .pending_messages
+                                    .entry((model_type_id, entity_id))
+                                    .or_default()
+                                    .push(message);
+                            }
+
                             continue;
                         };
 
-                        if let Some(handler) = state.message_handlers.remove(&payload_type_id) {
+                        if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
+                        {
                             drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
                             let future = handler(model, message, cx.clone());
-                            {
-                                let mut state = this.state.write();
-                                state.message_handlers.insert(payload_type_id, handler);
-                            }
 
                             let client_id = this.id;
                             log::debug!(
@@ -551,7 +635,7 @@ impl Client {
                                         }
                                         Err(error) => {
                                             log::error!(
-                                                "error handling rpc message. client_id:{}, name:{}, error:{}",
+                                                "error handling message. client_id:{}, name:{}, {}",
                                                 client_id,
                                                 type_name,
                                                 error
@@ -926,109 +1010,111 @@ mod tests {
         assert_eq!(decode_worktree_url("not://the-right-format"), None);
     }
 
-    // #[gpui::test]
-    // async fn test_subscribing_to_entity(mut cx: TestAppContext) {
-    //     cx.foreground().forbid_parking();
-
-    //     let user_id = 5;
-    //     let mut client = Client::new(FakeHttpClient::with_404_response());
-    //     let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
-    //     let model = cx.add_model(|_| Model { subscription: None });
-    //     let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
-    //     let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
-    //     let _subscription1 = model.update(&mut cx, |_, cx| {
-    //         client.add_entity_message_handler(
-    //             1,
-    //             cx,
-    //             move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
-    //                 postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
-    //                 async { Ok(()) }
-    //             },
-    //         )
-    //     });
-    //     let _subscription2 = model.update(&mut cx, |_, cx| {
-    //         client.add_entity_message_handler(
-    //             2,
-    //             cx,
-    //             move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
-    //                 postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
-    //                 async { Ok(()) }
-    //             },
-    //         )
-    //     });
-
-    //     // Ensure dropping a subscription for the same entity type still allows receiving of
-    //     // messages for other entity IDs of the same type.
-    //     let subscription3 = model.update(&mut cx, |_, cx| {
-    //         client.add_entity_message_handler(
-    //             3,
-    //             cx,
-    //             |_, _: TypedEnvelope<proto::UnshareProject>, _, _| async { Ok(()) },
-    //         )
-    //     });
-    //     drop(subscription3);
-
-    //     server.send(proto::UnshareProject { project_id: 1 });
-    //     server.send(proto::UnshareProject { project_id: 2 });
-    //     done_rx1.next().await.unwrap();
-    //     done_rx2.next().await.unwrap();
-    // }
-
-    // #[gpui::test]
-    // async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
-    //     cx.foreground().forbid_parking();
-
-    //     let user_id = 5;
-    //     let mut client = Client::new(FakeHttpClient::with_404_response());
-    //     let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
-    //     let model = cx.add_model(|_| Model { subscription: None });
-    //     let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
-    //     let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
-    //     let subscription1 = model.update(&mut cx, |_, cx| {
-    //         client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
-    //             postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
-    //             async { Ok(()) }
-    //         })
-    //     });
-    //     drop(subscription1);
-    //     let _subscription2 = model.update(&mut cx, |_, cx| {
-    //         client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
-    //             postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
-    //             async { Ok(()) }
-    //         })
-    //     });
-    //     server.send(proto::Ping {});
-    //     done_rx2.next().await.unwrap();
-    // }
-
-    // #[gpui::test]
-    // async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
-    //     cx.foreground().forbid_parking();
-
-    //     let user_id = 5;
-    //     let mut client = Client::new(FakeHttpClient::with_404_response());
-    //     let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
-    //     let model = cx.add_model(|_| Model { subscription: None });
-    //     let (mut done_tx, mut done_rx) = postage::oneshot::channel();
-    //     client.add_message_handler(
-    //         model.clone(),
-    //         move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
-    //             model.update(&mut cx, |model, _| model.subscription.take());
-    //             postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
-    //             async { Ok(()) }
-    //         },
-    //     );
-    //     model.update(&mut cx, |model, cx| {
-    //         model.subscription = Some();
-    //     });
-    //     server.send(proto::Ping {});
-    //     done_rx.next().await.unwrap();
-    // }
+    #[gpui::test]
+    async fn test_subscribing_to_entity(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let (done_tx1, mut done_rx1) = smol::channel::unbounded();
+        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+        client.add_entity_message_handler(
+            move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
+                match model.read_with(&cx, |model, _| model.id) {
+                    1 => done_tx1.try_send(()).unwrap(),
+                    2 => done_tx2.try_send(()).unwrap(),
+                    _ => unreachable!(),
+                }
+                async { Ok(()) }
+            },
+        );
+        let model1 = cx.add_model(|_| Model {
+            id: 1,
+            subscription: None,
+        });
+        let model2 = cx.add_model(|_| Model {
+            id: 2,
+            subscription: None,
+        });
+        let model3 = cx.add_model(|_| Model {
+            id: 3,
+            subscription: None,
+        });
+
+        let _subscription1 =
+            model1.update(&mut cx, |_, cx| client.add_model_for_remote_entity(1, cx));
+        let _subscription2 =
+            model2.update(&mut cx, |_, cx| client.add_model_for_remote_entity(2, cx));
+        // Ensure dropping a subscription for the same entity type still allows receiving of
+        // messages for other entity IDs of the same type.
+        let subscription3 =
+            model3.update(&mut cx, |_, cx| client.add_model_for_remote_entity(3, cx));
+        drop(subscription3);
+
+        server.send(proto::UnshareProject { project_id: 1 });
+        server.send(proto::UnshareProject { project_id: 2 });
+        done_rx1.next().await.unwrap();
+        done_rx2.next().await.unwrap();
+    }
+
+    #[gpui::test]
+    async fn test_subscribing_after_dropping_subscription(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let model = cx.add_model(|_| Model::default());
+        let (done_tx1, _done_rx1) = smol::channel::unbounded();
+        let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+        let subscription1 = client.add_message_handler(
+            model.clone(),
+            move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+                done_tx1.try_send(()).unwrap();
+                async { Ok(()) }
+            },
+        );
+        drop(subscription1);
+        let _subscription2 =
+            client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+                done_tx2.try_send(()).unwrap();
+                async { Ok(()) }
+            });
+        server.send(proto::Ping {});
+        done_rx2.next().await.unwrap();
+    }
+
+    #[gpui::test]
+    async fn test_dropping_subscription_in_handler(mut cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new(FakeHttpClient::with_404_response());
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        let model = cx.add_model(|_| Model::default());
+        let (done_tx, mut done_rx) = smol::channel::unbounded();
+        let subscription = client.add_message_handler(
+            model.clone(),
+            move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
+                model.update(&mut cx, |model, _| model.subscription.take());
+                done_tx.try_send(()).unwrap();
+                async { Ok(()) }
+            },
+        );
+        model.update(&mut cx, |model, _| {
+            model.subscription = Some(subscription);
+        });
+        server.send(proto::Ping {});
+        done_rx.next().await.unwrap();
+    }
 
+    #[derive(Default)]
     struct Model {
+        id: usize,
         subscription: Option<Subscription>,
     }
 

crates/project/src/project.rs 🔗

@@ -312,7 +312,7 @@ impl Project {
                 languages,
                 user_store,
                 fs,
-                subscriptions: vec![client.add_model_for_remote_entity(cx.handle(), remote_id)],
+                subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)],
                 client,
                 client_state: ProjectClientState::Remote {
                     sharing_has_stopped: false,
@@ -349,10 +349,8 @@ impl Project {
 
         self.subscriptions.clear();
         if let Some(remote_id) = remote_id {
-            self.subscriptions.push(
-                self.client
-                    .add_model_for_remote_entity(cx.handle(), remote_id),
-            );
+            self.subscriptions
+                .push(self.client.add_model_for_remote_entity(remote_id, cx));
         }
     }