Distribute operation workload evenly across peers in randomized test

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/server/src/rpc.rs | 148 ++++++++++++++++++++++-------------------
1 file changed, 81 insertions(+), 67 deletions(-)

Detailed changes

crates/server/src/rpc.rs 🔗

@@ -1110,7 +1110,6 @@ mod tests {
     use settings::Settings;
     use sqlx::types::time::OffsetDateTime;
     use std::{
-        cell::Cell,
         env,
         ops::Deref,
         path::{Path, PathBuf},
@@ -5000,10 +4999,10 @@ mod tests {
         )
         .await;
 
-        let operations = Rc::new(Cell::new(0));
         let mut server = TestServer::start(cx.foreground(), cx.background()).await;
         let mut clients = Vec::new();
         let mut user_ids = Vec::new();
+        let mut op_start_signals = Vec::new();
         let files = Arc::new(Mutex::new(Vec::new()));
 
         let mut next_entity_id = 100000;
@@ -5172,64 +5171,29 @@ mod tests {
         host_language_registry.add(Arc::new(language));
 
         let host_disconnected = Rc::new(AtomicBool::new(false));
+        let op_start_signal = futures::channel::mpsc::unbounded();
         user_ids.push(host.current_user_id(&host_cx));
+        op_start_signals.push(op_start_signal.0);
         clients.push(host_cx.foreground().spawn(host.simulate_host(
             host_project,
             files,
-            operations.clone(),
-            max_operations,
+            op_start_signal.1,
             rng.clone(),
             host_cx,
         )));
 
-        while operations.get() < max_operations {
-            cx.background().simulate_random_delay().await;
-            if clients.len() >= max_peers {
-                break;
-            } else if rng.lock().gen_bool(0.05) {
-                operations.set(operations.get() + 1);
-
-                let guest_id = clients.len();
-                log::info!("Adding guest {}", guest_id);
-                next_entity_id += 100000;
-                let mut guest_cx = TestAppContext::new(
-                    cx.foreground_platform(),
-                    cx.platform(),
-                    deterministic.build_foreground(next_entity_id),
-                    deterministic.build_background(),
-                    cx.font_cache(),
-                    cx.leak_detector(),
-                    next_entity_id,
-                );
-                let guest = server
-                    .create_client(&mut guest_cx, &format!("guest-{}", guest_id))
-                    .await;
-                let guest_project = Project::remote(
-                    host_project_id,
-                    guest.client.clone(),
-                    guest.user_store.clone(),
-                    guest_lang_registry.clone(),
-                    FakeFs::new(cx.background()),
-                    &mut guest_cx.to_async(),
-                )
-                .await
-                .unwrap();
-                user_ids.push(guest.current_user_id(&guest_cx));
-                clients.push(guest_cx.foreground().spawn(guest.simulate_guest(
-                    guest_id,
-                    guest_project,
-                    operations.clone(),
-                    max_operations,
-                    rng.clone(),
-                    host_disconnected.clone(),
-                    guest_cx,
-                )));
-
-                log::info!("Guest {} added", guest_id);
-            } else if rng.lock().gen_bool(0.05) {
+        let disconnect_host_at = if rng.lock().gen_bool(0.2) {
+            rng.lock().gen_range(0..max_operations)
+        } else {
+            max_operations
+        };
+        let mut operations = 0;
+        while operations < max_operations {
+            if operations == disconnect_host_at {
                 host_disconnected.store(true, SeqCst);
                 server.disconnect_client(user_ids[0]);
                 cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+                drop(op_start_signals);
                 let mut clients = futures::future::join_all(clients).await;
                 cx.foreground().run_until_parked();
 
@@ -5258,8 +5222,68 @@ mod tests {
 
                 return;
             }
+
+            let distribution = rng.lock().gen_range(0..100);
+            match distribution {
+                0..=19 if clients.len() < max_peers => {
+                    let guest_id = clients.len();
+                    log::info!("Adding guest {}", guest_id);
+                    next_entity_id += 100000;
+                    let mut guest_cx = TestAppContext::new(
+                        cx.foreground_platform(),
+                        cx.platform(),
+                        deterministic.build_foreground(next_entity_id),
+                        deterministic.build_background(),
+                        cx.font_cache(),
+                        cx.leak_detector(),
+                        next_entity_id,
+                    );
+                    let guest = server
+                        .create_client(&mut guest_cx, &format!("guest-{}", guest_id))
+                        .await;
+                    let guest_project = Project::remote(
+                        host_project_id,
+                        guest.client.clone(),
+                        guest.user_store.clone(),
+                        guest_lang_registry.clone(),
+                        FakeFs::new(cx.background()),
+                        &mut guest_cx.to_async(),
+                    )
+                    .await
+                    .unwrap();
+                    let op_start_signal = futures::channel::mpsc::unbounded();
+                    user_ids.push(guest.current_user_id(&guest_cx));
+                    op_start_signals.push(op_start_signal.0);
+                    clients.push(guest_cx.foreground().spawn(guest.simulate_guest(
+                        guest_id,
+                        guest_project,
+                        op_start_signal.1,
+                        rng.clone(),
+                        host_disconnected.clone(),
+                        guest_cx,
+                    )));
+
+                    log::info!("Guest {} added", guest_id);
+                    operations += 1;
+                }
+                _ => {
+                    while operations < max_operations && rng.lock().gen_bool(0.7) {
+                        op_start_signals
+                            .choose(&mut *rng.lock())
+                            .unwrap()
+                            .unbounded_send(())
+                            .unwrap();
+                        operations += 1;
+                    }
+
+                    if rng.lock().gen_bool(0.8) {
+                        cx.foreground().run_until_parked();
+                    }
+                }
+            }
         }
 
+        drop(op_start_signals);
         let mut clients = futures::future::join_all(clients).await;
         cx.foreground().run_until_parked();
 
@@ -5655,8 +5679,7 @@ mod tests {
             mut self,
             project: ModelHandle<Project>,
             files: Arc<Mutex<Vec<PathBuf>>>,
-            operations: Rc<Cell<usize>>,
-            max_operations: usize,
+            op_start_signal: futures::channel::mpsc::UnboundedReceiver<()>,
             rng: Arc<Mutex<StdRng>>,
             mut cx: TestAppContext,
         ) -> (Self, TestAppContext) {
@@ -5664,15 +5687,13 @@ mod tests {
                 client: &mut TestClient,
                 project: ModelHandle<Project>,
                 files: Arc<Mutex<Vec<PathBuf>>>,
-                operations: Rc<Cell<usize>>,
-                max_operations: usize,
+                mut op_start_signal: futures::channel::mpsc::UnboundedReceiver<()>,
                 rng: Arc<Mutex<StdRng>>,
                 cx: &mut TestAppContext,
             ) -> anyhow::Result<()> {
                 let fs = project.read_with(cx, |project, _| project.fs().clone());
-                while operations.get() < max_operations {
-                    operations.set(operations.get() + 1);
 
+                while op_start_signal.next().await.is_some() {
                     let distribution = rng.lock().gen_range::<usize, _>(0..100);
                     match distribution {
                         0..=20 if !files.lock().is_empty() => {
@@ -5784,8 +5805,7 @@ mod tests {
                 &mut self,
                 project.clone(),
                 files,
-                operations,
-                max_operations,
+                op_start_signal,
                 rng,
                 &mut cx,
             )
@@ -5800,8 +5820,7 @@ mod tests {
             mut self,
             guest_id: usize,
             project: ModelHandle<Project>,
-            operations: Rc<Cell<usize>>,
-            max_operations: usize,
+            op_start_signal: futures::channel::mpsc::UnboundedReceiver<()>,
             rng: Arc<Mutex<StdRng>>,
             host_disconnected: Rc<AtomicBool>,
             mut cx: TestAppContext,
@@ -5810,12 +5829,11 @@ mod tests {
                 client: &mut TestClient,
                 guest_id: usize,
                 project: ModelHandle<Project>,
-                operations: Rc<Cell<usize>>,
-                max_operations: usize,
+                mut op_start_signal: futures::channel::mpsc::UnboundedReceiver<()>,
                 rng: Arc<Mutex<StdRng>>,
                 cx: &mut TestAppContext,
             ) -> anyhow::Result<()> {
-                while operations.get() < max_operations {
+                while op_start_signal.next().await.is_some() {
                     let buffer = if client.buffers.is_empty() || rng.lock().gen() {
                         let worktree = if let Some(worktree) =
                             project.read_with(cx, |project, cx| {
@@ -5834,7 +5852,6 @@ mod tests {
                             continue;
                         };
 
-                        operations.set(operations.get() + 1);
                         let (worktree_root_name, project_path) =
                             worktree.read_with(cx, |worktree, _| {
                                 let entry = worktree
@@ -5870,8 +5887,6 @@ mod tests {
                         client.buffers.insert(buffer.clone());
                         buffer
                     } else {
-                        operations.set(operations.get() + 1);
-
                         client
                             .buffers
                             .iter()
@@ -6073,8 +6088,7 @@ mod tests {
                 &mut self,
                 guest_id,
                 project.clone(),
-                operations,
-                max_operations,
+                op_start_signal,
                 rng,
                 &mut cx,
             )