Avoid repeatedly loading/saving the test plan for each iteration

Max Brunsfeld created

Change summary

crates/collab/src/tests/randomized_integration_tests.rs | 64 +++++++---
1 file changed, 43 insertions(+), 21 deletions(-)

Detailed changes

crates/collab/src/tests/randomized_integration_tests.rs 🔗

@@ -33,6 +33,11 @@ use std::{
 };
 use util::ResultExt;
 
+lazy_static::lazy_static! {
+    static ref LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Default::default();
+    static ref DID_SAVE_PLAN_JSON: AtomicBool = Default::default();
+}
+
 #[gpui::test(iterations = 100)]
 async fn test_random_collaboration(
     cx: &mut TestAppContext,
@@ -99,8 +104,14 @@ async fn test_random_collaboration(
     let plan = Arc::new(Mutex::new(TestPlan::new(rng, users, max_operations)));
 
     if let Some(path) = &plan_load_path {
-        eprintln!("loaded plan from path {:?}", path);
-        plan.lock().load(path);
+        let json = LOADED_PLAN_JSON
+            .lock()
+            .get_or_insert_with(|| {
+                eprintln!("loaded test plan from path {:?}", path);
+                std::fs::read(path).unwrap()
+            })
+            .clone();
+        plan.lock().deserialize(json);
     }
 
     let mut clients = Vec::new();
@@ -132,8 +143,10 @@ async fn test_random_collaboration(
     deterministic.run_until_parked();
 
     if let Some(path) = &plan_save_path {
-        eprintln!("saved test plan to path {:?}", path);
-        plan.lock().save(path);
+        if !DID_SAVE_PLAN_JSON.swap(true, SeqCst) {
+            eprintln!("saved test plan to path {:?}", path);
+            std::fs::write(path, plan.lock().serialize()).unwrap();
+        }
     }
 
     for (client, client_cx) in &clients {
@@ -313,28 +326,38 @@ async fn test_random_collaboration(
                     host_buffer.read_with(host_cx, |b, _| b.saved_version().clone());
                 let guest_saved_version =
                     guest_buffer.read_with(client_cx, |b, _| b.saved_version().clone());
-                assert_eq!(guest_saved_version, host_saved_version);
+                assert_eq!(
+                    guest_saved_version, host_saved_version,
+                    "guest saved version does not match host's for path {path:?} in project {project_id}",
+                );
 
                 let host_saved_version_fingerprint =
                     host_buffer.read_with(host_cx, |b, _| b.saved_version_fingerprint());
                 let guest_saved_version_fingerprint =
                     guest_buffer.read_with(client_cx, |b, _| b.saved_version_fingerprint());
                 assert_eq!(
-                    guest_saved_version_fingerprint,
-                    host_saved_version_fingerprint
+                    guest_saved_version_fingerprint, host_saved_version_fingerprint,
+                    "guest's saved fingerprint does not match host's for path {path:?} in project {project_id}",
                 );
 
                 let host_saved_mtime = host_buffer.read_with(host_cx, |b, _| b.saved_mtime());
                 let guest_saved_mtime = guest_buffer.read_with(client_cx, |b, _| b.saved_mtime());
-                assert_eq!(guest_saved_mtime, host_saved_mtime);
+                assert_eq!(
+                    guest_saved_mtime, host_saved_mtime,
+                    "guest's saved mtime does not match host's for path {path:?} in project {project_id}",
+                );
 
                 let host_is_dirty = host_buffer.read_with(host_cx, |b, _| b.is_dirty());
                 let guest_is_dirty = guest_buffer.read_with(client_cx, |b, _| b.is_dirty());
-                assert_eq!(guest_is_dirty, host_is_dirty);
+                assert_eq!(guest_is_dirty, host_is_dirty,
+                    "guest's dirty status does not match host's for path {path:?} in project {project_id}",
+                );
 
                 let host_has_conflict = host_buffer.read_with(host_cx, |b, _| b.has_conflict());
                 let guest_has_conflict = guest_buffer.read_with(client_cx, |b, _| b.has_conflict());
-                assert_eq!(guest_has_conflict, host_has_conflict);
+                assert_eq!(guest_has_conflict, host_has_conflict,
+                    "guest's conflict status does not match host's for path {path:?} in project {project_id}",
+                );
             }
         }
     }
@@ -797,12 +820,12 @@ async fn apply_client_operation(
                 .ok_or(TestError::Inapplicable)?;
 
             log::info!(
-                "{}: saving buffer {:?} in {} project {}{}",
+                "{}: saving buffer {:?} in {} project {}, {}",
                 client.username,
                 full_path,
                 if is_local { "local" } else { "remote" },
                 project_root_name,
-                if detach { ", detaching" } else { ", awaiting" }
+                if detach { "detaching" } else { "awaiting" }
             );
 
             ensure_project_shared(&project, client, cx).await;
@@ -836,13 +859,13 @@ async fn apply_client_operation(
                 .ok_or(TestError::Inapplicable)?;
 
             log::info!(
-                "{}: request LSP {:?} for buffer {:?} in {} project {}{}",
+                "{}: request LSP {:?} for buffer {:?} in {} project {}, {}",
                 client.username,
                 kind,
                 full_path,
                 if is_local { "local" } else { "remote" },
                 project_root_name,
-                if detach { ", detaching" } else { ", awaiting" }
+                if detach { "detaching" } else { "awaiting" }
             );
 
             use futures::{FutureExt as _, TryFutureExt as _};
@@ -888,12 +911,12 @@ async fn apply_client_operation(
                 .ok_or(TestError::Inapplicable)?;
 
             log::info!(
-                "{}: search {} project {} for {:?}{}",
+                "{}: search {} project {} for {:?}, {}",
                 client.username,
                 if is_local { "local" } else { "remote" },
                 project_root_name,
                 query,
-                if detach { ", detaching" } else { ", awaiting" }
+                if detach { "detaching" } else { "awaiting" }
             );
 
             let search = project.update(cx, |project, cx| {
@@ -1137,10 +1160,9 @@ impl TestPlan {
         }
     }
 
-    fn load(&mut self, path: &Path) {
-        let json = std::fs::read_to_string(path).unwrap();
+    fn deserialize(&mut self, json: Vec<u8>) {
+        let stored_operations: Vec<StoredOperation> = serde_json::from_slice(&json).unwrap();
         self.replay = true;
-        let stored_operations: Vec<StoredOperation> = serde_json::from_str(&json).unwrap();
         self.stored_operations = stored_operations
             .iter()
             .cloned()
@@ -1171,7 +1193,7 @@ impl TestPlan {
             .collect()
     }
 
-    fn save(&mut self, path: &Path) {
+    fn serialize(&mut self) -> Vec<u8> {
         // Format each operation as one line
         let mut json = Vec::new();
         json.push(b'[');
@@ -1186,7 +1208,7 @@ impl TestPlan {
             serde_json::to_writer(&mut json, operation).unwrap();
         }
         json.extend_from_slice(b"\n]\n");
-        std::fs::write(path, &json).unwrap();
+        json
     }
 
     fn next_server_operation(