Merge pull request #801 from zed-industries/randomized-test-improvements

Antonio Scandurra created

Introduce guest disconnection in randomized collaboration test

Change summary

crates/collab/Cargo.toml              |   1 
crates/collab/src/rpc.rs              | 570 +++++++++++++++++++---------
crates/collab/src/rpc/store.rs        |  25 -
crates/gpui_macros/src/gpui_macros.rs |  99 ++--
crates/rpc/src/peer.rs                |   6 
5 files changed, 436 insertions(+), 265 deletions(-)

Detailed changes

crates/collab/Cargo.toml 🔗

@@ -15,6 +15,7 @@ required-features = ["seed-support"]
 [dependencies]
 collections = { path = "../collections" }
 rpc = { path = "../rpc" }
+util = { path = "../util" }
 anyhow = "1.0.40"
 async-io = "1.3"
 async-std = { version = "1.8.0", features = ["attributes"] }

crates/collab/src/rpc.rs 🔗

@@ -7,12 +7,14 @@ use super::{
 };
 use anyhow::anyhow;
 use async_io::Timer;
-use async_std::task;
+use async_std::{
+    sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+    task,
+};
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use collections::{HashMap, HashSet};
 use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt};
 use log::{as_debug, as_display};
-use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
 use rpc::{
     proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
     Connection, ConnectionId, Peer, TypedEnvelope,
@@ -21,6 +23,9 @@ use sha1::{Digest as _, Sha1};
 use std::{
     any::TypeId,
     future::Future,
+    marker::PhantomData,
+    ops::{Deref, DerefMut},
+    rc::Rc,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -31,6 +36,7 @@ use tide::{
     Request, Response,
 };
 use time::OffsetDateTime;
+use util::ResultExt;
 
 type MessageHandler = Box<
     dyn Send
@@ -58,6 +64,16 @@ pub struct RealExecutor;
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
 const MAX_MESSAGE_LEN: usize = 1024;
 
+struct StoreReadGuard<'a> {
+    guard: RwLockReadGuard<'a, Store>,
+    _not_send: PhantomData<Rc<()>>,
+}
+
+struct StoreWriteGuard<'a> {
+    guard: RwLockWriteGuard<'a, Store>,
+    _not_send: PhantomData<Rc<()>>,
+}
+
 impl Server {
     pub fn new(
         app_state: Arc<AppState>,
@@ -78,7 +94,7 @@ impl Server {
             .add_message_handler(Server::unregister_project)
             .add_request_handler(Server::share_project)
             .add_message_handler(Server::unshare_project)
-            .add_request_handler(Server::join_project)
+            .add_sync_request_handler(Server::join_project)
             .add_message_handler(Server::leave_project)
             .add_request_handler(Server::register_worktree)
             .add_message_handler(Server::unregister_worktree)
@@ -170,6 +186,42 @@ impl Server {
         })
     }
 
+    /// Handle a request while holding a lock to the store. This is useful when we're registering
+    /// a connection but we want to respond on the connection before anybody else can send on it.
+    fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
+    where
+        F: 'static
+            + Send
+            + Sync
+            + Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> tide::Result<M::Response>,
+        M: RequestMessage,
+    {
+        let handler = Arc::new(handler);
+        self.add_message_handler(move |server, envelope| {
+            let receipt = envelope.receipt();
+            let handler = handler.clone();
+            async move {
+                let mut store = server.store.write().await;
+                let response = (handler)(server.clone(), &mut *store, envelope);
+                match response {
+                    Ok(response) => {
+                        server.peer.respond(receipt, response)?;
+                        Ok(())
+                    }
+                    Err(error) => {
+                        server.peer.respond_with_error(
+                            receipt,
+                            proto::Error {
+                                message: error.to_string(),
+                            },
+                        )?;
+                        Err(error)
+                    }
+                }
+            }
+        })
+    }
+
     pub fn handle_connection<E: Executor>(
         self: &Arc<Self>,
         connection: Connection,
@@ -197,9 +249,10 @@ impl Server {
                 let _ = send_connection_id.send(connection_id).await;
             }
 
-            this.state_mut().add_connection(connection_id, user_id);
-            if let Err(err) = this.update_contacts_for_users(&[user_id]) {
-                log::error!("error updating contacts for {:?}: {}", user_id, err);
+            {
+                let mut state = this.state_mut().await;
+                state.add_connection(connection_id, user_id);
+                this.update_contacts_for_users(&*state, &[user_id]);
             }
 
             let handle_io = handle_io.fuse();
@@ -257,7 +310,8 @@ impl Server {
 
     async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id);
-        let removed_connection = self.state_mut().remove_connection(connection_id)?;
+        let mut state = self.state_mut().await;
+        let removed_connection = state.remove_connection(connection_id)?;
 
         for (project_id, project) in removed_connection.hosted_projects {
             if let Some(share) = project.share {
@@ -268,7 +322,7 @@ impl Server {
                         self.peer
                             .send(conn_id, proto::UnshareProject { project_id })
                     },
-                )?;
+                );
             }
         }
 
@@ -281,10 +335,10 @@ impl Server {
                         peer_id: connection_id.0,
                     },
                 )
-            })?;
+            });
         }
 
-        self.update_contacts_for_users(removed_connection.contact_ids.iter())?;
+        self.update_contacts_for_users(&*state, removed_connection.contact_ids.iter());
         Ok(())
     }
 
@@ -293,11 +347,11 @@ impl Server {
     }
 
     async fn register_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterProject>,
     ) -> tide::Result<proto::RegisterProjectResponse> {
         let project_id = {
-            let mut state = self.state_mut();
+            let mut state = self.state_mut().await;
             let user_id = state.user_id_for_connection(request.sender_id)?;
             state.register_project(request.sender_id, user_id)
         };
@@ -305,51 +359,49 @@ impl Server {
     }
 
     async fn unregister_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnregisterProject>,
     ) -> tide::Result<()> {
-        let project = self
-            .state_mut()
-            .unregister_project(request.payload.project_id, request.sender_id)?;
-        self.update_contacts_for_users(project.authorized_user_ids().iter())?;
+        let mut state = self.state_mut().await;
+        let project = state.unregister_project(request.payload.project_id, request.sender_id)?;
+        self.update_contacts_for_users(&*state, &project.authorized_user_ids());
         Ok(())
     }
 
     async fn share_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::ShareProject>,
     ) -> tide::Result<proto::Ack> {
         self.state_mut()
+            .await
             .share_project(request.payload.project_id, request.sender_id);
         Ok(proto::Ack {})
     }
 
     async fn unshare_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnshareProject>,
     ) -> tide::Result<()> {
         let project_id = request.payload.project_id;
-        let project = self
-            .state_mut()
-            .unshare_project(project_id, request.sender_id)?;
-
+        let mut state = self.state_mut().await;
+        let project = state.unshare_project(project_id, request.sender_id)?;
         broadcast(request.sender_id, project.connection_ids, |conn_id| {
             self.peer
                 .send(conn_id, proto::UnshareProject { project_id })
-        })?;
-        self.update_contacts_for_users(&project.authorized_user_ids)?;
+        });
+        self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
         Ok(())
     }
 
-    async fn join_project(
-        mut self: Arc<Server>,
+    fn join_project(
+        self: Arc<Server>,
+        state: &mut Store,
         request: TypedEnvelope<proto::JoinProject>,
     ) -> tide::Result<proto::JoinProjectResponse> {
         let project_id = request.payload.project_id;
 
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
-        let (response, connection_ids, contact_user_ids) = self
-            .state_mut()
+        let user_id = state.user_id_for_connection(request.sender_id)?;
+        let (response, connection_ids, contact_user_ids) = state
             .join_project(request.sender_id, user_id, project_id)
             .and_then(|joined| {
                 let share = joined.project.share()?;
@@ -410,19 +462,19 @@ impl Server {
                     }),
                 },
             )
-        })?;
-        self.update_contacts_for_users(&contact_user_ids)?;
+        });
+        self.update_contacts_for_users(state, &contact_user_ids);
         Ok(response)
     }
 
     async fn leave_project(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::LeaveProject>,
     ) -> tide::Result<()> {
         let sender_id = request.sender_id;
         let project_id = request.payload.project_id;
-        let worktree = self.state_mut().leave_project(sender_id, project_id)?;
-
+        let mut state = self.state_mut().await;
+        let worktree = state.leave_project(sender_id, project_id)?;
         broadcast(sender_id, worktree.connection_ids, |conn_id| {
             self.peer.send(
                 conn_id,
@@ -431,60 +483,57 @@ impl Server {
                     peer_id: sender_id.0,
                 },
             )
-        })?;
-        self.update_contacts_for_users(&worktree.authorized_user_ids)?;
-
+        });
+        self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
         Ok(())
     }
 
     async fn register_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::RegisterWorktree>,
     ) -> tide::Result<proto::Ack> {
-        let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
-
         let mut contact_user_ids = HashSet::default();
-        contact_user_ids.insert(host_user_id);
         for github_login in &request.payload.authorized_logins {
             let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
             contact_user_ids.insert(contact_user_id);
         }
 
+        let mut state = self.state_mut().await;
+        let host_user_id = state.user_id_for_connection(request.sender_id)?;
+        contact_user_ids.insert(host_user_id);
+
         let contact_user_ids = contact_user_ids.into_iter().collect::<Vec<_>>();
-        let guest_connection_ids;
-        {
-            let mut state = self.state_mut();
-            guest_connection_ids = state
-                .read_project(request.payload.project_id, request.sender_id)?
-                .guest_connection_ids();
-            state.register_worktree(
-                request.payload.project_id,
-                request.payload.worktree_id,
-                request.sender_id,
-                Worktree {
-                    authorized_user_ids: contact_user_ids.clone(),
-                    root_name: request.payload.root_name.clone(),
-                    visible: request.payload.visible,
-                },
-            )?;
-        }
+        let guest_connection_ids = state
+            .read_project(request.payload.project_id, request.sender_id)?
+            .guest_connection_ids();
+        state.register_worktree(
+            request.payload.project_id,
+            request.payload.worktree_id,
+            request.sender_id,
+            Worktree {
+                authorized_user_ids: contact_user_ids.clone(),
+                root_name: request.payload.root_name.clone(),
+                visible: request.payload.visible,
+            },
+        )?;
+
         broadcast(request.sender_id, guest_connection_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
-        self.update_contacts_for_users(&contact_user_ids)?;
+        });
+        self.update_contacts_for_users(&*state, &contact_user_ids);
         Ok(proto::Ack {})
     }
 
     async fn unregister_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UnregisterWorktree>,
     ) -> tide::Result<()> {
         let project_id = request.payload.project_id;
         let worktree_id = request.payload.worktree_id;
+        let mut state = self.state_mut().await;
         let (worktree, guest_connection_ids) =
-            self.state_mut()
-                .unregister_worktree(project_id, worktree_id, request.sender_id)?;
+            state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
         broadcast(request.sender_id, guest_connection_ids, |conn_id| {
             self.peer.send(
                 conn_id,
@@ -493,16 +542,16 @@ impl Server {
                     worktree_id,
                 },
             )
-        })?;
-        self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+        });
+        self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
         Ok(())
     }
 
     async fn update_worktree(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<proto::Ack> {
-        let connection_ids = self.state_mut().update_worktree(
+        let connection_ids = self.state_mut().await.update_worktree(
             request.sender_id,
             request.payload.project_id,
             request.payload.worktree_id,
@@ -513,13 +562,13 @@ impl Server {
         broadcast(request.sender_id, connection_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
 
         Ok(proto::Ack {})
     }
 
     async fn update_diagnostic_summary(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
     ) -> tide::Result<()> {
         let summary = request
@@ -527,7 +576,7 @@ impl Server {
             .summary
             .clone()
             .ok_or_else(|| anyhow!("invalid summary"))?;
-        let receiver_ids = self.state_mut().update_diagnostic_summary(
+        let receiver_ids = self.state_mut().await.update_diagnostic_summary(
             request.payload.project_id,
             request.payload.worktree_id,
             request.sender_id,
@@ -537,15 +586,15 @@ impl Server {
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
     async fn start_language_server(
-        mut self: Arc<Server>,
+        self: Arc<Server>,
         request: TypedEnvelope<proto::StartLanguageServer>,
     ) -> tide::Result<()> {
-        let receiver_ids = self.state_mut().start_language_server(
+        let receiver_ids = self.state_mut().await.start_language_server(
             request.payload.project_id,
             request.sender_id,
             request
@@ -557,7 +606,7 @@ impl Server {
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -567,11 +616,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -584,6 +634,7 @@ impl Server {
     {
         let host_connection_id = self
             .state()
+            .await
             .read_project(request.payload.remote_entity_id(), request.sender_id)?
             .host_connection_id;
         Ok(self
@@ -596,24 +647,25 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::SaveBuffer>,
     ) -> tide::Result<proto::BufferSaved> {
-        let host;
-        let mut guests;
-        {
-            let state = self.state();
-            let project = state.read_project(request.payload.project_id, request.sender_id)?;
-            host = project.host_connection_id;
-            guests = project.guest_connection_ids()
-        }
-
+        let host = self
+            .state()
+            .await
+            .read_project(request.payload.project_id, request.sender_id)?
+            .host_connection_id;
         let response = self
             .peer
             .forward_request(request.sender_id, host, request.payload.clone())
             .await?;
 
+        let mut guests = self
+            .state()
+            .await
+            .read_project(request.payload.project_id, request.sender_id)?
+            .connection_ids();
         guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
         broadcast(host, guests, |conn_id| {
             self.peer.forward_send(host, conn_id, response.clone())
-        })?;
+        });
 
         Ok(response)
     }
@@ -624,11 +676,12 @@ impl Server {
     ) -> tide::Result<proto::Ack> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(proto::Ack {})
     }
 
@@ -638,11 +691,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -652,11 +706,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -666,11 +721,12 @@ impl Server {
     ) -> tide::Result<()> {
         let receiver_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         broadcast(request.sender_id, receiver_ids, |connection_id| {
             self.peer
                 .forward_send(request.sender_id, connection_id, request.payload.clone())
-        })?;
+        });
         Ok(())
     }
 
@@ -682,6 +738,7 @@ impl Server {
         let follower_id = request.sender_id;
         if !self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, follower_id)?
             .contains(&leader_id)
         {
@@ -704,6 +761,7 @@ impl Server {
         let leader_id = ConnectionId(request.payload.leader_id);
         if !self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?
             .contains(&leader_id)
         {
@@ -720,6 +778,7 @@ impl Server {
     ) -> tide::Result<()> {
         let connection_ids = self
             .state()
+            .await
             .project_connection_ids(request.payload.project_id, request.sender_id)?;
         let leader_id = request
             .payload
@@ -744,7 +803,10 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetChannels>,
     ) -> tide::Result<proto::GetChannelsResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channels = self.app_state.db.get_accessible_channels(user_id).await?;
         Ok(proto::GetChannelsResponse {
             channels: channels
@@ -783,32 +845,33 @@ impl Server {
     }
 
     fn update_contacts_for_users<'a>(
-        self: &Arc<Server>,
+        self: &Arc<Self>,
+        state: &Store,
         user_ids: impl IntoIterator<Item = &'a UserId>,
-    ) -> anyhow::Result<()> {
-        let mut result = Ok(());
-        let state = self.state();
+    ) {
         for user_id in user_ids {
             let contacts = state.contacts_for_user(*user_id);
             for connection_id in state.connection_ids_for_user(*user_id) {
-                if let Err(error) = self.peer.send(
-                    connection_id,
-                    proto::UpdateContacts {
-                        contacts: contacts.clone(),
-                    },
-                ) {
-                    result = Err(error);
-                }
+                self.peer
+                    .send(
+                        connection_id,
+                        proto::UpdateContacts {
+                            contacts: contacts.clone(),
+                        },
+                    )
+                    .log_err();
             }
         }
-        result
     }
 
     async fn join_channel(
-        mut self: Arc<Self>,
+        self: Arc<Self>,
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<proto::JoinChannelResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -819,7 +882,9 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state_mut().join_channel(request.sender_id, channel_id);
+        self.state_mut()
+            .await
+            .join_channel(request.sender_id, channel_id);
         let messages = self
             .app_state
             .db
@@ -841,10 +906,13 @@ impl Server {
     }
 
     async fn leave_channel(
-        mut self: Arc<Self>,
+        self: Arc<Self>,
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -856,6 +924,7 @@ impl Server {
         }
 
         self.state_mut()
+            .await
             .leave_channel(request.sender_id, channel_id);
 
         Ok(())
@@ -869,7 +938,7 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.state();
+            let state = self.state().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
             connection_ids = state.channel_connection_ids(channel_id)?;
         }
@@ -910,7 +979,7 @@ impl Server {
                     message: Some(message.clone()),
                 },
             )
-        })?;
+        });
         Ok(proto::SendChannelMessageResponse {
             message: Some(message),
         })
@@ -920,7 +989,10 @@ impl Server {
         self: Arc<Self>,
         request: TypedEnvelope<proto::GetChannelMessages>,
     ) -> tide::Result<proto::GetChannelMessagesResponse> {
-        let user_id = self.state().user_id_for_connection(request.sender_id)?;
+        let user_id = self
+            .state()
+            .await
+            .user_id_for_connection(request.sender_id)?;
         let channel_id = ChannelId::from_proto(request.payload.channel_id);
         if !self
             .app_state
@@ -956,12 +1028,57 @@ impl Server {
         })
     }
 
-    fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
-        self.store.read()
+    async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        let guard = self.store.read().await;
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        StoreReadGuard {
+            guard,
+            _not_send: PhantomData,
+        }
+    }
+
+    async fn state_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        let guard = self.store.write().await;
+        #[cfg(test)]
+        async_std::task::yield_now().await;
+        StoreWriteGuard {
+            guard,
+            _not_send: PhantomData,
+        }
+    }
+}
+
+impl<'a> Deref for StoreReadGuard<'a> {
+    type Target = Store;
+
+    fn deref(&self) -> &Self::Target {
+        &*self.guard
+    }
+}
+
+impl<'a> Deref for StoreWriteGuard<'a> {
+    type Target = Store;
+
+    fn deref(&self) -> &Self::Target {
+        &*self.guard
     }
+}
 
-    fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
-        self.store.write()
+impl<'a> DerefMut for StoreWriteGuard<'a> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut *self.guard
+    }
+}
+
+impl<'a> Drop for StoreWriteGuard<'a> {
+    fn drop(&mut self) {
+        #[cfg(test)]
+        self.check_invariants();
     }
 }
 
@@ -977,25 +1094,15 @@ impl Executor for RealExecutor {
     }
 }
 
-fn broadcast<F>(
-    sender_id: ConnectionId,
-    receiver_ids: Vec<ConnectionId>,
-    mut f: F,
-) -> anyhow::Result<()>
+fn broadcast<F>(sender_id: ConnectionId, receiver_ids: Vec<ConnectionId>, mut f: F)
 where
     F: FnMut(ConnectionId) -> anyhow::Result<()>,
 {
-    let mut result = Ok(());
     for receiver_id in receiver_ids {
         if receiver_id != sender_id {
-            if let Err(error) = f(receiver_id) {
-                if result.is_ok() {
-                    result = Err(error);
-                }
-            }
+            f(receiver_id).log_err();
         }
     }
-    result
 }
 
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -1087,7 +1194,11 @@ mod tests {
         self, ConfirmCodeAction, ConfirmCompletion, ConfirmRename, Editor, Input, Redo, Rename,
         ToOffset, ToggleCodeActions, Undo,
     };
-    use gpui::{executor, geometry::vector::vec2f, ModelHandle, TestAppContext, ViewHandle};
+    use gpui::{
+        executor::{self, Deterministic},
+        geometry::vector::vec2f,
+        ModelHandle, TestAppContext, ViewHandle,
+    };
     use language::{
         range_to_lsp, tree_sitter_rust, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language,
         LanguageConfig, LanguageRegistry, OffsetRangeExt, Point, Rope,
@@ -1106,7 +1217,6 @@ mod tests {
     use settings::Settings;
     use sqlx::types::time::OffsetDateTime;
     use std::{
-        cell::Cell,
         env,
         ops::Deref,
         path::{Path, PathBuf},
@@ -1118,7 +1228,6 @@ mod tests {
         time::Duration,
     };
     use theme::ThemeRegistry;
-    use util::TryFutureExt;
     use workspace::{Item, SplitDirection, ToggleFollow, Workspace, WorkspaceParams};
 
     #[cfg(test)]
@@ -4975,11 +5084,17 @@ mod tests {
     }
 
     #[gpui::test(iterations = 100)]
-    async fn test_random_collaboration(cx: &mut TestAppContext, rng: StdRng) {
+    async fn test_random_collaboration(
+        cx: &mut TestAppContext,
+        deterministic: Arc<Deterministic>,
+        rng: StdRng,
+    ) {
         cx.foreground().forbid_parking();
         let max_peers = env::var("MAX_PEERS")
             .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
             .unwrap_or(5);
+        assert!(max_peers <= 5);
+
         let max_operations = env::var("OPERATIONS")
             .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
             .unwrap_or(10);
@@ -4993,23 +5108,23 @@ mod tests {
         fs.insert_tree(
             "/_collab",
             json!({
-                ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4", "guest-5"]"#
+                ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4"]"#
             }),
         )
         .await;
 
-        let operations = Rc::new(Cell::new(0));
         let mut server = TestServer::start(cx.foreground(), cx.background()).await;
         let mut clients = Vec::new();
         let mut user_ids = Vec::new();
+        let mut op_start_signals = Vec::new();
         let files = Arc::new(Mutex::new(Vec::new()));
 
         let mut next_entity_id = 100000;
         let mut host_cx = TestAppContext::new(
             cx.foreground_platform(),
             cx.platform(),
-            cx.foreground(),
-            cx.background(),
+            deterministic.build_foreground(next_entity_id),
+            deterministic.build_background(),
             cx.font_cache(),
             cx.leak_detector(),
             next_entity_id,
@@ -5169,77 +5284,53 @@ mod tests {
         });
         host_language_registry.add(Arc::new(language));
 
-        let host_disconnected = Rc::new(AtomicBool::new(false));
+        let op_start_signal = futures::channel::mpsc::unbounded();
         user_ids.push(host.current_user_id(&host_cx));
-        clients.push(cx.foreground().spawn(host.simulate_host(
+        op_start_signals.push(op_start_signal.0);
+        clients.push(host_cx.foreground().spawn(host.simulate_host(
             host_project,
             files,
-            operations.clone(),
-            max_operations,
+            op_start_signal.1,
             rng.clone(),
             host_cx,
         )));
 
-        while operations.get() < max_operations {
-            cx.background().simulate_random_delay().await;
-            if clients.len() >= max_peers {
-                break;
-            } else if rng.lock().gen_bool(0.05) {
-                operations.set(operations.get() + 1);
-
-                let guest_id = clients.len();
-                log::info!("Adding guest {}", guest_id);
-                next_entity_id += 100000;
-                let mut guest_cx = TestAppContext::new(
-                    cx.foreground_platform(),
-                    cx.platform(),
-                    cx.foreground(),
-                    cx.background(),
-                    cx.font_cache(),
-                    cx.leak_detector(),
-                    next_entity_id,
-                );
-                let guest = server
-                    .create_client(&mut guest_cx, &format!("guest-{}", guest_id))
-                    .await;
-                let guest_project = Project::remote(
-                    host_project_id,
-                    guest.client.clone(),
-                    guest.user_store.clone(),
-                    guest_lang_registry.clone(),
-                    FakeFs::new(cx.background()),
-                    &mut guest_cx.to_async(),
-                )
-                .await
-                .unwrap();
-                user_ids.push(guest.current_user_id(&guest_cx));
-                clients.push(cx.foreground().spawn(guest.simulate_guest(
-                    guest_id,
-                    guest_project,
-                    operations.clone(),
-                    max_operations,
-                    rng.clone(),
-                    host_disconnected.clone(),
-                    guest_cx,
-                )));
-
-                log::info!("Guest {} added", guest_id);
-            } else if rng.lock().gen_bool(0.05) {
-                host_disconnected.store(true, SeqCst);
+        let disconnect_host_at = if rng.lock().gen_bool(0.2) {
+            rng.lock().gen_range(0..max_operations)
+        } else {
+            max_operations
+        };
+        let mut available_guests = vec![
+            "guest-1".to_string(),
+            "guest-2".to_string(),
+            "guest-3".to_string(),
+            "guest-4".to_string(),
+        ];
+        let mut operations = 0;
+        while operations < max_operations {
+            if operations == disconnect_host_at {
                 server.disconnect_client(user_ids[0]);
                 cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+                drop(op_start_signals);
                 let mut clients = futures::future::join_all(clients).await;
                 cx.foreground().run_until_parked();
 
-                let (host, mut host_cx) = clients.remove(0);
+                let (host, mut host_cx, host_err) = clients.remove(0);
+                if let Some(host_err) = host_err {
+                    log::error!("host error - {}", host_err);
+                }
                 host.project
                     .as_ref()
                     .unwrap()
                     .read_with(&host_cx, |project, _| assert!(!project.is_shared()));
-                for (guest, mut guest_cx) in clients {
+                for (guest, mut guest_cx, guest_err) in clients {
+                    if let Some(guest_err) = guest_err {
+                        log::error!("{} error - {}", guest.username, guest_err);
+                    }
                     let contacts = server
                         .store
                         .read()
+                        .await
                         .contacts_for_user(guest.current_user_id(&guest_cx));
                     assert!(!contacts
                         .iter()
@@ -5256,12 +5347,113 @@ mod tests {
 
                 return;
             }
+
+            let distribution = rng.lock().gen_range(0..100);
+            match distribution {
+                0..=19 if !available_guests.is_empty() => {
+                    let guest_ix = rng.lock().gen_range(0..available_guests.len());
+                    let guest_username = available_guests.remove(guest_ix);
+                    log::info!("Adding new connection for {}", guest_username);
+                    next_entity_id += 100000;
+                    let mut guest_cx = TestAppContext::new(
+                        cx.foreground_platform(),
+                        cx.platform(),
+                        deterministic.build_foreground(next_entity_id),
+                        deterministic.build_background(),
+                        cx.font_cache(),
+                        cx.leak_detector(),
+                        next_entity_id,
+                    );
+                    let guest = server.create_client(&mut guest_cx, &guest_username).await;
+                    let guest_project = Project::remote(
+                        host_project_id,
+                        guest.client.clone(),
+                        guest.user_store.clone(),
+                        guest_lang_registry.clone(),
+                        FakeFs::new(cx.background()),
+                        &mut guest_cx.to_async(),
+                    )
+                    .await
+                    .unwrap();
+                    let op_start_signal = futures::channel::mpsc::unbounded();
+                    user_ids.push(guest.current_user_id(&guest_cx));
+                    op_start_signals.push(op_start_signal.0);
+                    clients.push(guest_cx.foreground().spawn(guest.simulate_guest(
+                        guest_username.clone(),
+                        guest_project,
+                        op_start_signal.1,
+                        rng.clone(),
+                        guest_cx,
+                    )));
+
+                    log::info!("Added connection for {}", guest_username);
+                    operations += 1;
+                }
+                20..=29 if clients.len() > 1 => {
+                    log::info!("Removing guest");
+                    let guest_ix = rng.lock().gen_range(1..clients.len());
+                    let removed_guest_id = user_ids.remove(guest_ix);
+                    let guest = clients.remove(guest_ix);
+                    op_start_signals.remove(guest_ix);
+                    server.disconnect_client(removed_guest_id);
+                    cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+                    let (guest, mut guest_cx, guest_err) = guest.await;
+                    if let Some(guest_err) = guest_err {
+                        log::error!("{} error - {}", guest.username, guest_err);
+                    }
+                    guest
+                        .project
+                        .as_ref()
+                        .unwrap()
+                        .read_with(&guest_cx, |project, _| assert!(project.is_read_only()));
+                    for user_id in &user_ids {
+                        for contact in server.store.read().await.contacts_for_user(*user_id) {
+                            assert_ne!(
+                                contact.user_id, removed_guest_id.0 as u64,
+                                "removed guest is still a contact of another peer"
+                            );
+                            for project in contact.projects {
+                                for project_guest_id in project.guests {
+                                    assert_ne!(
+                                        project_guest_id, removed_guest_id.0 as u64,
+                                        "removed guest appears as still participating on a project"
+                                    );
+                                }
+                            }
+                        }
+                    }
+
+                    log::info!("{} removed", guest.username);
+                    available_guests.push(guest.username.clone());
+                    guest_cx.update(|_| drop(guest));
+
+                    operations += 1;
+                }
+                _ => {
+                    while operations < max_operations && rng.lock().gen_bool(0.7) {
+                        op_start_signals
+                            .choose(&mut *rng.lock())
+                            .unwrap()
+                            .unbounded_send(())
+                            .unwrap();
+                        operations += 1;
+                    }
+
+                    if rng.lock().gen_bool(0.8) {
+                        cx.foreground().run_until_parked();
+                    }
+                }
+            }
         }
 
+        drop(op_start_signals);
         let mut clients = futures::future::join_all(clients).await;
         cx.foreground().run_until_parked();
 
-        let (host_client, mut host_cx) = clients.remove(0);
+        let (host_client, mut host_cx, host_err) = clients.remove(0);
+        if let Some(host_err) = host_err {
+            panic!("host error - {}", host_err);
+        }
         let host_project = host_client.project.as_ref().unwrap();
         let host_worktree_snapshots = host_project.read_with(&host_cx, |project, cx| {
             project

crates/collab/src/rpc/store.rs 🔗

@@ -130,9 +130,6 @@ impl Store {
             }
         }
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(result)
     }
 
@@ -275,8 +272,6 @@ impl Store {
                 share.worktrees.insert(worktree_id, Default::default());
             }
 
-            #[cfg(test)]
-            self.check_invariants();
             Ok(())
         } else {
             Err(anyhow!("no such project"))?
@@ -313,8 +308,6 @@ impl Store {
                         }
                     }
 
-                    #[cfg(test)]
-                    self.check_invariants();
                     Ok(project)
                 } else {
                     Err(anyhow!("no such project"))?
@@ -359,9 +352,6 @@ impl Store {
             }
         }
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok((worktree, guest_connection_ids))
     }
 
@@ -403,9 +393,6 @@ impl Store {
                 }
             }
 
-            #[cfg(test)]
-            self.check_invariants();
-
             Ok(UnsharedProject {
                 connection_ids,
                 authorized_user_ids,
@@ -491,9 +478,6 @@ impl Store {
         share.active_replica_ids.insert(replica_id);
         share.guests.insert(connection_id, (replica_id, user_id));
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(JoinedProject {
             replica_id,
             project: &self.projects[&project_id],
@@ -526,9 +510,6 @@ impl Store {
         let connection_ids = project.connection_ids();
         let authorized_user_ids = project.authorized_user_ids();
 
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(LeftProject {
             connection_ids,
             authorized_user_ids,
@@ -556,10 +537,6 @@ impl Store {
             worktree.entries.insert(entry.id, entry.clone());
         }
         let connection_ids = project.connection_ids();
-
-        #[cfg(test)]
-        self.check_invariants();
-
         Ok(connection_ids)
     }
 
@@ -633,7 +610,7 @@ impl Store {
     }
 
     #[cfg(test)]
-    fn check_invariants(&self) {
+    pub fn check_invariants(&self) {
         for (connection_id, connection) in &self.connections {
             for project_id in &connection.projects {
                 let project = &self.projects.get(&project_id).unwrap();

crates/gpui_macros/src/gpui_macros.rs 🔗

@@ -75,68 +75,65 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                     match last_segment.map(|s| s.ident.to_string()).as_deref() {
                         Some("StdRng") => {
                             inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
+                            continue;
                         }
                         Some("bool") => {
                             inner_fn_args.extend(quote!(is_last_iteration,));
+                            continue;
                         }
-                        _ => {
-                            return TokenStream::from(
-                                syn::Error::new_spanned(arg, "invalid argument")
-                                    .into_compile_error(),
-                            )
-                        }
-                    }
-                } else if let Type::Reference(ty) = &*arg.ty {
-                    match &*ty.elem {
-                        Type::Path(ty) => {
-                            let last_segment = ty.path.segments.last();
-                            match last_segment.map(|s| s.ident.to_string()).as_deref() {
-                                Some("TestAppContext") => {
-                                    let first_entity_id = ix * 100_000;
-                                    let cx_varname = format_ident!("cx_{}", ix);
-                                    cx_vars.extend(quote!(
-                                        let mut #cx_varname = #namespace::TestAppContext::new(
-                                            foreground_platform.clone(),
-                                            cx.platform().clone(),
-                                            deterministic.build_foreground(#ix),
-                                            deterministic.build_background(),
-                                            cx.font_cache().clone(),
-                                            cx.leak_detector(),
-                                            #first_entity_id,
-                                        );
-                                    ));
-                                    cx_teardowns.extend(quote!(
-                                        #cx_varname.update(|cx| cx.remove_all_windows());
-                                        deterministic.run_until_parked();
-                                        #cx_varname.update(|_| {}); // flush effects
-                                    ));
-                                    inner_fn_args.extend(quote!(&mut #cx_varname,));
-                                }
-                                _ => {
-                                    return TokenStream::from(
-                                        syn::Error::new_spanned(arg, "invalid argument")
-                                            .into_compile_error(),
-                                    )
+                        Some("Arc") => {
+                            if let syn::PathArguments::AngleBracketed(args) =
+                                &last_segment.unwrap().arguments
+                            {
+                                if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
+                                    args.args.last()
+                                {
+                                    let last_segment = ty.path.segments.last();
+                                    if let Some("Deterministic") =
+                                        last_segment.map(|s| s.ident.to_string()).as_deref()
+                                    {
+                                        inner_fn_args.extend(quote!(deterministic.clone(),));
+                                        continue;
+                                    }
                                 }
                             }
                         }
-                        _ => {
-                            return TokenStream::from(
-                                syn::Error::new_spanned(arg, "invalid argument")
-                                    .into_compile_error(),
-                            )
+                        _ => {}
+                    }
+                } else if let Type::Reference(ty) = &*arg.ty {
+                    if let Type::Path(ty) = &*ty.elem {
+                        let last_segment = ty.path.segments.last();
+                        if let Some("TestAppContext") =
+                            last_segment.map(|s| s.ident.to_string()).as_deref()
+                        {
+                            let first_entity_id = ix * 100_000;
+                            let cx_varname = format_ident!("cx_{}", ix);
+                            cx_vars.extend(quote!(
+                                let mut #cx_varname = #namespace::TestAppContext::new(
+                                    foreground_platform.clone(),
+                                    cx.platform().clone(),
+                                    deterministic.build_foreground(#ix),
+                                    deterministic.build_background(),
+                                    cx.font_cache().clone(),
+                                    cx.leak_detector(),
+                                    #first_entity_id,
+                                );
+                            ));
+                            cx_teardowns.extend(quote!(
+                                #cx_varname.update(|cx| cx.remove_all_windows());
+                                deterministic.run_until_parked();
+                                #cx_varname.update(|_| {}); // flush effects
+                            ));
+                            inner_fn_args.extend(quote!(&mut #cx_varname,));
+                            continue;
                         }
                     }
-                } else {
-                    return TokenStream::from(
-                        syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
-                    );
                 }
-            } else {
-                return TokenStream::from(
-                    syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
-                );
             }
+
+            return TokenStream::from(
+                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
+            );
         }
 
         parse_quote! {

crates/rpc/src/peer.rs 🔗

@@ -126,7 +126,11 @@ impl Peer {
         // can always send messages without yielding. For incoming messages, use a
         // bounded channel so that other peers will receive backpressure if they send
         // messages faster than this peer can process them.
-        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
+        #[cfg(any(test, feature = "test-support"))]
+        const INCOMING_BUFFER_SIZE: usize = 1;
+        #[cfg(not(any(test, feature = "test-support")))]
+        const INCOMING_BUFFER_SIZE: usize = 64;
+        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
         let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
 
         let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));