WIP: Allow subscribing to remote entity before creating a model

Antonio Scandurra , Nathan Sobo , and Max Brunsfeld created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

crates/client/src/client.rs   | 287 ++++++++++++++++++++++--------------
crates/project/src/project.rs |  30 ++-
2 files changed, 193 insertions(+), 124 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -17,8 +17,7 @@ use gpui::{
     actions,
     serde_json::{self, Value},
     AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext,
-    AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext,
-    ViewHandle,
+    AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle,
 };
 use http::HttpClient;
 use lazy_static::lazy_static;
@@ -34,6 +33,7 @@ use std::{
     convert::TryFrom,
     fmt::Write as _,
     future::Future,
+    marker::PhantomData,
     path::PathBuf,
     sync::{Arc, Weak},
     time::{Duration, Instant},
@@ -172,7 +172,7 @@ struct ClientState {
     entity_id_extractors: HashMap<TypeId, fn(&dyn AnyTypedEnvelope) -> u64>,
     _reconnect_task: Option<Task<()>>,
     reconnect_interval: Duration,
-    entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>,
+    entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
     models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
     entity_types_by_message_type: HashMap<TypeId, TypeId>,
     #[allow(clippy::type_complexity)]
@@ -182,7 +182,7 @@ struct ClientState {
             dyn Send
                 + Sync
                 + Fn(
-                    AnyEntityHandle,
+                    Subscriber,
                     Box<dyn AnyTypedEnvelope>,
                     &Arc<Client>,
                     AsyncAppContext,
@@ -191,12 +191,13 @@ struct ClientState {
     >,
 }
 
-enum AnyWeakEntityHandle {
+enum WeakSubscriber {
     Model(AnyWeakModelHandle),
     View(AnyWeakViewHandle),
+    Pending(Vec<Box<dyn AnyTypedEnvelope>>),
 }
 
-enum AnyEntityHandle {
+enum Subscriber {
     Model(AnyModelHandle),
     View(AnyViewHandle),
 }
@@ -254,6 +255,54 @@ impl Drop for Subscription {
     }
 }
 
+pub struct PendingEntitySubscription<T: Entity> {
+    client: Arc<Client>,
+    remote_id: u64,
+    _entity_type: PhantomData<T>,
+    consumed: bool,
+}
+
+impl<T: Entity> PendingEntitySubscription<T> {
+    pub fn set_model(mut self, model: &ModelHandle<T>, cx: &mut AsyncAppContext) -> Subscription {
+        self.consumed = true;
+        let mut state = self.client.state.write();
+        let id = (TypeId::of::<T>(), self.remote_id);
+        let Some(WeakSubscriber::Pending(messages)) =
+            state.entities_by_type_and_remote_id.remove(&id)
+        else {
+            unreachable!()
+        };
+
+        state
+            .entities_by_type_and_remote_id
+            .insert(id, WeakSubscriber::Model(model.downgrade().into()));
+        drop(state);
+        for message in messages {
+            self.client.handle_message(message, cx);
+        }
+        Subscription::Entity {
+            client: Arc::downgrade(&self.client),
+            id,
+        }
+    }
+}
+
+impl<T: Entity> Drop for PendingEntitySubscription<T> {
+    fn drop(&mut self) {
+        if !self.consumed {
+            let mut state = self.client.state.write();
+            if let Some(WeakSubscriber::Pending(messages)) = state
+                .entities_by_type_and_remote_id
+                .remove(&(TypeId::of::<T>(), self.remote_id))
+            {
+                for message in messages {
+                    log::info!("unhandled message {}", message.payload_type_name());
+                }
+            }
+        }
+    }
+}
+
 impl Client {
     pub fn new(http: Arc<dyn HttpClient>, cx: &AppContext) -> Arc<Self> {
         Arc::new(Self {
@@ -387,26 +436,28 @@ impl Client {
         self.state
             .write()
             .entities_by_type_and_remote_id
-            .insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into()));
+            .insert(id, WeakSubscriber::View(cx.weak_handle().into()));
         Subscription::Entity {
             client: Arc::downgrade(self),
             id,
         }
     }
 
-    pub fn add_model_for_remote_entity<T: Entity>(
+    pub fn subscribe_to_entity<T: Entity>(
         self: &Arc<Self>,
         remote_id: u64,
-        cx: &mut ModelContext<T>,
-    ) -> Subscription {
+    ) -> PendingEntitySubscription<T> {
         let id = (TypeId::of::<T>(), remote_id);
         self.state
             .write()
             .entities_by_type_and_remote_id
-            .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into()));
-        Subscription::Entity {
-            client: Arc::downgrade(self),
-            id,
+            .insert(id, WeakSubscriber::Pending(Default::default()));
+
+        PendingEntitySubscription {
+            client: self.clone(),
+            remote_id,
+            consumed: false,
+            _entity_type: PhantomData,
         }
     }
 
@@ -434,7 +485,7 @@ impl Client {
         let prev_handler = state.message_handlers.insert(
             message_type_id,
             Arc::new(move |handle, envelope, client, cx| {
-                let handle = if let AnyEntityHandle::Model(handle) = handle {
+                let handle = if let Subscriber::Model(handle) = handle {
                     handle
                 } else {
                     unreachable!();
@@ -488,7 +539,7 @@ impl Client {
         F: 'static + Future<Output = Result<()>>,
     {
         self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
-            if let AnyEntityHandle::View(handle) = handle {
+            if let Subscriber::View(handle) = handle {
                 handler(handle.downcast::<E>().unwrap(), message, client, cx)
             } else {
                 unreachable!();
@@ -507,7 +558,7 @@ impl Client {
         F: 'static + Future<Output = Result<()>>,
     {
         self.add_entity_message_handler::<M, E, _, _>(move |handle, message, client, cx| {
-            if let AnyEntityHandle::Model(handle) = handle {
+            if let Subscriber::Model(handle) = handle {
                 handler(handle.downcast::<E>().unwrap(), message, client, cx)
             } else {
                 unreachable!();
@@ -522,7 +573,7 @@ impl Client {
         H: 'static
             + Send
             + Sync
-            + Fn(AnyEntityHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+            + Fn(Subscriber, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
         F: 'static + Future<Output = Result<()>>,
     {
         let model_type_id = TypeId::of::<E>();
@@ -784,94 +835,8 @@ impl Client {
                 let cx = cx.clone();
                 let this = self.clone();
                 async move {
-                    let mut message_id = 0_usize;
                     while let Some(message) = incoming.next().await {
-                        let mut state = this.state.write();
-                        message_id += 1;
-                        let type_name = message.payload_type_name();
-                        let payload_type_id = message.payload_type_id();
-                        let sender_id = message.original_sender_id().map(|id| id.0);
-
-                        let model = state
-                            .models_by_message_type
-                            .get(&payload_type_id)
-                            .and_then(|model| model.upgrade(&cx))
-                            .map(AnyEntityHandle::Model)
-                            .or_else(|| {
-                                let entity_type_id =
-                                    *state.entity_types_by_message_type.get(&payload_type_id)?;
-                                let entity_id = state
-                                    .entity_id_extractors
-                                    .get(&message.payload_type_id())
-                                    .map(|extract_entity_id| {
-                                        (extract_entity_id)(message.as_ref())
-                                    })?;
-
-                                let entity = state
-                                    .entities_by_type_and_remote_id
-                                    .get(&(entity_type_id, entity_id))?;
-                                if let Some(entity) = entity.upgrade(&cx) {
-                                    Some(entity)
-                                } else {
-                                    state
-                                        .entities_by_type_and_remote_id
-                                        .remove(&(entity_type_id, entity_id));
-                                    None
-                                }
-                            });
-
-                        let model = if let Some(model) = model {
-                            model
-                        } else {
-                            log::info!("unhandled message {}", type_name);
-                            continue;
-                        };
-
-                        let handler = state.message_handlers.get(&payload_type_id).cloned();
-                        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
-                        // It also ensures we don't hold the lock while yielding back to the executor, as
-                        // that might cause the executor thread driving this future to block indefinitely.
-                        drop(state);
-
-                        if let Some(handler) = handler {
-                            let future = handler(model, message, &this, cx.clone());
-                            let client_id = this.id;
-                            log::debug!(
-                                "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
-                                client_id,
-                                message_id,
-                                sender_id,
-                                type_name
-                            );
-                            cx.foreground()
-                                .spawn(async move {
-                                    match future.await {
-                                        Ok(()) => {
-                                            log::debug!(
-                                                "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}",
-                                                client_id,
-                                                message_id,
-                                                sender_id,
-                                                type_name
-                                            );
-                                        }
-                                        Err(error) => {
-                                            log::error!(
-                                                "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}",
-                                                client_id,
-                                                message_id,
-                                                sender_id,
-                                                type_name,
-                                                error
-                                            );
-                                        }
-                                    }
-                                })
-                                .detach();
-                        } else {
-                            log::info!("unhandled message {}", type_name);
-                        }
-
+                        this.handle_message(message, &cx);
                         // Don't starve the main thread when receiving lots of messages at once.
                         smol::future::yield_now().await;
                     }
@@ -1218,6 +1183,97 @@ impl Client {
         self.peer.respond_with_error(receipt, error)
     }
 
+    fn handle_message(
+        self: &Arc<Client>,
+        message: Box<dyn AnyTypedEnvelope>,
+        cx: &AsyncAppContext,
+    ) {
+        let mut state = self.state.write();
+        let type_name = message.payload_type_name();
+        let payload_type_id = message.payload_type_id();
+        let sender_id = message.original_sender_id().map(|id| id.0);
+
+        let mut subscriber = None;
+
+        if let Some(message_model) = state
+            .models_by_message_type
+            .get(&payload_type_id)
+            .and_then(|model| model.upgrade(cx))
+        {
+            subscriber = Some(Subscriber::Model(message_model));
+        } else if let Some((extract_entity_id, entity_type_id)) =
+            state.entity_id_extractors.get(&payload_type_id).zip(
+                state
+                    .entity_types_by_message_type
+                    .get(&payload_type_id)
+                    .copied(),
+            )
+        {
+            let entity_id = (extract_entity_id)(message.as_ref());
+
+            match state
+                .entities_by_type_and_remote_id
+                .get_mut(&(entity_type_id, entity_id))
+            {
+                Some(WeakSubscriber::Pending(pending)) => {
+                    pending.push(message);
+                    return;
+                }
+                Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx),
+                _ => {}
+            }
+        }
+
+        let subscriber = if let Some(subscriber) = subscriber {
+            subscriber
+        } else {
+            log::info!("unhandled message {}", type_name);
+            return;
+        };
+
+        let handler = state.message_handlers.get(&payload_type_id).cloned();
+        // Dropping the state prevents deadlocks if the handler interacts with rpc::Client.
+        // It also ensures we don't hold the lock while yielding back to the executor, as
+        // that might cause the executor thread driving this future to block indefinitely.
+        drop(state);
+
+        if let Some(handler) = handler {
+            let future = handler(subscriber, message, &self, cx.clone());
+            let client_id = self.id;
+            log::debug!(
+                "rpc message received. client_id:{}, sender_id:{:?}, type:{}",
+                client_id,
+                sender_id,
+                type_name
+            );
+            cx.foreground()
+                .spawn(async move {
+                    match future.await {
+                        Ok(()) => {
+                            log::debug!(
+                                "rpc message handled. client_id:{}, sender_id:{:?}, type:{}",
+                                client_id,
+                                sender_id,
+                                type_name
+                            );
+                        }
+                        Err(error) => {
+                            log::error!(
+                                "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}",
+                                client_id,
+                                sender_id,
+                                type_name,
+                                error
+                            );
+                        }
+                    }
+                })
+                .detach();
+        } else {
+            log::info!("unhandled message {}", type_name);
+        }
+    }
+
     pub fn start_telemetry(&self, db: Db) {
         self.telemetry.start(db.clone());
     }
@@ -1231,11 +1287,12 @@ impl Client {
     }
 }
 
-impl AnyWeakEntityHandle {
-    fn upgrade(&self, cx: &AsyncAppContext) -> Option<AnyEntityHandle> {
+impl WeakSubscriber {
+    fn upgrade(&self, cx: &AsyncAppContext) -> Option<Subscriber> {
         match self {
-            AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model),
-            AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View),
+            WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model),
+            WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View),
+            WeakSubscriber::Pending(_) => None,
         }
     }
 }
@@ -1480,11 +1537,17 @@ mod tests {
             subscription: None,
         });
 
-        let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx));
-        let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx));
+        let _subscription1 = client
+            .subscribe_to_entity(1)
+            .set_model(&model1, &mut cx.to_async());
+        let _subscription2 = client
+            .subscribe_to_entity(2)
+            .set_model(&model2, &mut cx.to_async());
         // Ensure dropping a subscription for the same entity type still allows receiving of
         // messages for other entity IDs of the same type.
-        let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx));
+        let subscription3 = client
+            .subscribe_to_entity(3)
+            .set_model(&model3, &mut cx.to_async());
         drop(subscription3);
 
         server.send(proto::JoinProject { project_id: 1 });

crates/project/src/project.rs 🔗

@@ -457,22 +457,23 @@ impl Project {
     ) -> Result<ModelHandle<Self>, JoinProjectError> {
         client.authenticate_and_connect(true, &cx).await?;
 
+        let subscription = client.subscribe_to_entity(remote_id);
         let response = client
             .request(proto::JoinProject {
                 project_id: remote_id,
             })
             .await?;
+        let this = cx.add_model(|cx| {
+            let replica_id = response.replica_id as ReplicaId;
 
-        let replica_id = response.replica_id as ReplicaId;
-
-        let mut worktrees = Vec::new();
-        for worktree in response.worktrees {
-            let worktree = cx
-                .update(|cx| Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx));
-            worktrees.push(worktree);
-        }
+            let mut worktrees = Vec::new();
+            for worktree in response.worktrees {
+                let worktree = cx.update(|cx| {
+                    Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx)
+                });
+                worktrees.push(worktree);
+            }
 
-        let this = cx.add_model(|cx: &mut ModelContext<Self>| {
             let mut this = Self {
                 worktrees: Vec::new(),
                 loading_buffers: Default::default(),
@@ -488,7 +489,7 @@ impl Project {
                 fs,
                 next_entry_id: Default::default(),
                 next_diagnostic_group_id: Default::default(),
-                client_subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)],
+                client_subscriptions: Default::default(),
                 _subscriptions: Default::default(),
                 client: client.clone(),
                 client_state: Some(ProjectClientState::Remote {
@@ -541,6 +542,7 @@ impl Project {
             }
             this
         });
+        let subscription = subscription.set_model(&this, &mut cx);
 
         let user_ids = response
             .collaborators
@@ -558,6 +560,7 @@ impl Project {
 
         this.update(&mut cx, |this, _| {
             this.collaborators = collaborators;
+            this.client_subscriptions.push(subscription);
         });
 
         Ok(this)
@@ -1035,8 +1038,11 @@ impl Project {
             });
         }
 
-        self.client_subscriptions
-            .push(self.client.add_model_for_remote_entity(project_id, cx));
+        self.client_subscriptions.push(
+            self.client
+                .subscribe_to_entity(project_id)
+                .set_model(&cx.handle(), &mut cx.to_async()),
+        );
         let _ = self.metadata_changed(cx);
         cx.emit(Event::RemoteIdChanged(Some(project_id)));
         cx.notify();