randomized_test_helpers.rs

  1use crate::{TestClient, TestServer};
  2use async_trait::async_trait;
  3use collab::{
  4    db::{self, NewUserParams, UserId},
  5    rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  6};
  7use futures::StreamExt;
  8use gpui::{BackgroundExecutor, Task, TestAppContext};
  9use parking_lot::Mutex;
 10use rand::prelude::*;
 11use rpc::RECEIVE_TIMEOUT;
 12use serde::{Deserialize, Serialize, de::DeserializeOwned};
 13use settings::SettingsStore;
 14use std::sync::OnceLock;
 15use std::{
 16    env,
 17    path::PathBuf,
 18    rc::Rc,
 19    sync::{
 20        Arc,
 21        atomic::{AtomicBool, Ordering::SeqCst},
 22    },
 23};
 24
 25fn plan_load_path() -> &'static Option<PathBuf> {
 26    static PLAN_LOAD_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
 27    PLAN_LOAD_PATH.get_or_init(|| path_env_var("LOAD_PLAN"))
 28}
 29
 30fn plan_save_path() -> &'static Option<PathBuf> {
 31    static PLAN_SAVE_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
 32    PLAN_SAVE_PATH.get_or_init(|| path_env_var("SAVE_PLAN"))
 33}
 34
 35fn max_peers() -> usize {
 36    static MAX_PEERS: OnceLock<usize> = OnceLock::new();
 37    *MAX_PEERS.get_or_init(|| {
 38        env::var("MAX_PEERS")
 39            .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 40            .unwrap_or(3)
 41    })
 42}
 43
 44fn max_operations() -> usize {
 45    static MAX_OPERATIONS: OnceLock<usize> = OnceLock::new();
 46    *MAX_OPERATIONS.get_or_init(|| {
 47        env::var("OPERATIONS")
 48            .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 49            .unwrap_or(10)
 50    })
 51}
 52
/// Raw bytes of the plan file named by `LOAD_PLAN`. Filled lazily by
/// `TestPlan::new`, so the file is read from disk at most once per process.
static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
/// Closure that serializes the plan of the most recently started run. Installed
/// by `run_randomized_test` and consumed by `save_randomized_test_plan`.
static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
 55
/// The sequence of server and client operations for one randomized test run,
/// either generated on the fly from `rng` or replayed from a loaded plan file.
struct TestPlan<T: RandomizedTest> {
    /// Source of randomness for all generated operations.
    rng: StdRng,
    /// Set by `deserialize`; when true, operations are read back from
    /// `stored_operations` instead of being generated.
    replay: bool,
    /// Every operation handed out so far (or loaded for replay), paired with a
    /// flag recording whether it was actually applied. Only flagged operations
    /// are written by `serialize`.
    stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
    /// Generation stops once `operation_ix` reaches this value.
    max_operations: usize,
    /// When generating: number of budgeted operations produced so far.
    /// When replaying: cursor into `stored_operations` for server operations.
    operation_ix: usize,
    /// One entry per user created in `new`.
    users: Vec<UserTestPlan>,
    /// Id assigned to the next `MutateClients` batch.
    next_batch_id: usize,
    /// Whether `RestartServer` operations may be generated.
    allow_server_restarts: bool,
    /// Whether `BounceConnection` operations may be generated.
    allow_client_reconnection: bool,
    /// Whether `RemoveConnection` operations may be generated.
    allow_client_disconnection: bool,
}
 68
/// Per-user state tracked by a `TestPlan`.
pub struct UserTestPlan {
    /// Database id returned when the user was created.
    pub user_id: UserId,
    /// Name of the form `user-N`, also used as the GitHub login.
    pub username: String,
    /// Copy of the plan-wide setting chosen in `TestPlan::new`.
    pub allow_client_disconnection: bool,
    /// Counter backing `next_root_dir_name`.
    next_root_id: usize,
    /// Replay cursor into the plan's stored operations for this user.
    operation_ix: usize,
    /// True between an applied `AddConnection` and an applied `RemoveConnection`.
    online: bool,
}
 77
/// One entry of a saved test plan: either a server-level operation or an
/// operation performed by a specific client.
///
/// Serialized untagged, so the JSON of an entry is just the JSON of the variant's
/// contents.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StoredOperation<T> {
    /// An operation driven by the test harness itself.
    Server(ServerOperation),
    /// An operation applied by the client of `user_id` as part of batch `batch_id`.
    Client {
        user_id: UserId,
        batch_id: usize,
        operation: T,
    },
}
 88
/// Operations the harness performs on the server and on client connections.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum ServerOperation {
    /// Create a client for the user and start simulating it.
    AddConnection {
        user_id: UserId,
    },
    /// Disconnect the user's client and wait for its task to finish.
    RemoveConnection {
        user_id: UserId,
    },
    /// Disconnect the user's client without removing it.
    BounceConnection {
        user_id: UserId,
    },
    /// Reset and start the server again.
    RestartServer,
    /// Signal a batch of clients to each perform their next operation.
    MutateClients {
        batch_id: usize,
        // Not stored in the plan file: when a plan is loaded, `deserialize`
        // rebuilds this list from the client operations that share `batch_id`.
        #[serde(skip_serializing)]
        #[serde(skip_deserializing)]
        user_ids: Vec<UserId>,
        /// When true (and the batch applied), `on_quiesce` runs after the batch.
        quiesce: bool,
    },
}
109
/// Error returned by `RandomizedTest::apply_operation`.
pub enum TestError {
    /// The operation could not be applied in the current state. The client
    /// simulation logs it as skipped and marks it as not applied.
    Inapplicable,
    /// Any other failure. The client simulation logs it as an error and continues.
    Other(anyhow::Error),
}
114
/// Hooks that a concrete randomized test supplies to the harness in this file.
#[async_trait(?Send)]
pub trait RandomizedTest: 'static + Sized {
    /// A single client-side action. It is cloned into the stored plan and
    /// (de)serialized as part of the plan file.
    type Operation: Send + Clone + Serialize + DeserializeOwned;

    /// Produces the next operation for `client`. Called only when the plan is
    /// being generated, not when it is replayed.
    fn generate_operation(
        client: &TestClient,
        rng: &mut StdRng,
        plan: &mut UserTestPlan,
        cx: &TestAppContext,
    ) -> Self::Operation;

    /// Applies `operation` on behalf of `client`.
    ///
    /// Return `Err(TestError::Inapplicable)` to have the operation recorded as
    /// not applied; other errors are logged by the caller.
    async fn apply_operation(
        client: &TestClient,
        operation: Self::Operation,
        cx: &mut TestAppContext,
    ) -> Result<(), TestError>;

    /// Called once from `TestPlan::new`, after all users have been created.
    async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);

    /// Called at the start of each client's simulation task. Does nothing by default.
    async fn on_client_added(_client: &Rc<TestClient>, _cx: &mut TestAppContext) {}

    /// Called after the executor has been run until parked: after a quiescing
    /// `MutateClients` batch, and once at the end of the run.
    async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
}
138
/// Runs one randomized test for the test type `T`.
///
/// Starts a `TestServer`, builds a `TestPlan` from `rng` (replaying a plan file
/// instead when `LOAD_PLAN` is set), and applies server operations until the plan
/// yields none. It then waits for every client task to finish, calls
/// `T::on_quiesce`, tears down each client's app state, and writes the plan to
/// the `SAVE_PLAN` path when one is configured.
///
/// # Panics
///
/// Panics if writing the plan file fails.
pub async fn run_randomized_test<T: RandomizedTest>(
    cx: &mut TestAppContext,
    executor: BackgroundExecutor,
    rng: StdRng,
) {
    let mut server = TestServer::start(executor.clone()).await;
    let plan = TestPlan::<T>::new(&mut server, rng).await;

    // Publish a serializer for this plan so that `save_randomized_test_plan`
    // can write it out later.
    LAST_PLAN.lock().replace({
        let plan = plan.clone();
        Box::new(move || plan.lock().serialize())
    });

    // These three vectors are index-aligned: entry `i` of each belongs to the
    // same client (see the add/remove handling in `apply_server_operation`).
    let mut clients = Vec::new();
    let mut client_tasks = Vec::new();
    let mut operation_channels = Vec::new();
    loop {
        let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
            break;
        };
        // Mark the operation as applied up front, and revert the flag if it
        // turns out not to have applied.
        applied.store(true, SeqCst);
        let did_apply = TestPlan::apply_server_operation(
            plan.clone(),
            executor.clone(),
            &mut server,
            &mut clients,
            &mut client_tasks,
            &mut operation_channels,
            next_operation,
            cx,
        )
        .await;
        if !did_apply {
            applied.store(false, SeqCst);
        }
    }

    // Dropping the senders closes each client's channel, which ends the
    // receive loop in `simulate_client`.
    drop(operation_channels);
    futures::future::join_all(client_tasks).await;

    executor.run_until_parked();
    T::on_quiesce(&mut server, &mut clients).await;

    for (client, cx) in clients {
        // Close every window of this client's app.
        cx.update(|cx| {
            for window in cx.windows() {
                window
                    .update(cx, |_, window, _| window.remove_window())
                    .ok();
            }
        });
        // Clear all globals but carry the `SettingsStore` over, re-initialize
        // the theme, and drop the client inside the app update.
        cx.update(|cx| {
            let settings = cx.remove_global::<SettingsStore>();
            cx.clear_globals();
            cx.set_global(settings);
            theme::init(theme::LoadThemes::JustBase, cx);
            drop(client);
        });
        executor.run_until_parked();
    }

    if let Some(path) = plan_save_path() {
        eprintln!("saved test plan to path {:?}", path);
        std::fs::write(path, plan.lock().serialize()).unwrap();
    }
}
205
206pub fn save_randomized_test_plan() {
207    if let Some(serialize_plan) = LAST_PLAN.lock().take()
208        && let Some(path) = plan_save_path()
209    {
210        eprintln!("saved test plan to path {:?}", path);
211        std::fs::write(path, serialize_plan()).unwrap();
212    }
213}
214
215impl<T: RandomizedTest> TestPlan<T> {
216    pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
217        let allow_server_restarts = rng.random_bool(0.7);
218        let allow_client_reconnection = rng.random_bool(0.7);
219        let allow_client_disconnection = rng.random_bool(0.1);
220
221        let mut users = Vec::new();
222        for ix in 0..max_peers() {
223            let username = format!("user-{}", ix + 1);
224            let user_id = server
225                .app_state
226                .db
227                .create_user(
228                    &format!("{username}@example.com"),
229                    None,
230                    false,
231                    NewUserParams {
232                        github_login: username.clone(),
233                        github_user_id: ix as i32,
234                    },
235                )
236                .await
237                .unwrap()
238                .user_id;
239            users.push(UserTestPlan {
240                user_id,
241                username,
242                online: false,
243                next_root_id: 0,
244                operation_ix: 0,
245                allow_client_disconnection,
246            });
247        }
248
249        T::initialize(server, &users).await;
250
251        let plan = Arc::new(Mutex::new(Self {
252            replay: false,
253            allow_server_restarts,
254            allow_client_reconnection,
255            allow_client_disconnection,
256            stored_operations: Vec::new(),
257            operation_ix: 0,
258            next_batch_id: 0,
259            max_operations: max_operations(),
260            users,
261            rng,
262        }));
263
264        if let Some(path) = plan_load_path() {
265            let json = LOADED_PLAN_JSON
266                .lock()
267                .get_or_insert_with(|| {
268                    eprintln!("loaded test plan from path {:?}", path);
269                    std::fs::read(path).unwrap()
270                })
271                .clone();
272            plan.lock().deserialize(json);
273        }
274
275        plan
276    }
277
278    fn deserialize(&mut self, json: Vec<u8>) {
279        let stored_operations: Vec<StoredOperation<T::Operation>> =
280            serde_json::from_slice(&json).unwrap();
281        self.replay = true;
282        self.stored_operations = stored_operations
283            .iter()
284            .cloned()
285            .enumerate()
286            .map(|(i, mut operation)| {
287                let did_apply = Arc::new(AtomicBool::new(false));
288                if let StoredOperation::Server(ServerOperation::MutateClients {
289                    batch_id: current_batch_id,
290                    user_ids,
291                    ..
292                }) = &mut operation
293                {
294                    assert!(user_ids.is_empty());
295                    user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
296                        if let StoredOperation::Client {
297                            user_id, batch_id, ..
298                        } = operation
299                            && batch_id == current_batch_id
300                        {
301                            return Some(user_id);
302                        }
303                        None
304                    }));
305                    user_ids.sort_unstable();
306                }
307                (operation, did_apply)
308            })
309            .collect()
310    }
311
312    fn serialize(&mut self) -> Vec<u8> {
313        // Format each operation as one line
314        let mut json = Vec::new();
315        json.push(b'[');
316        for (operation, applied) in &self.stored_operations {
317            if !applied.load(SeqCst) {
318                continue;
319            }
320            if json.len() > 1 {
321                json.push(b',');
322            }
323            json.extend_from_slice(b"\n  ");
324            serde_json::to_writer(&mut json, operation).unwrap();
325        }
326        json.extend_from_slice(b"\n]\n");
327        json
328    }
329
330    fn next_server_operation(
331        &mut self,
332        clients: &[(Rc<TestClient>, TestAppContext)],
333    ) -> Option<(ServerOperation, Arc<AtomicBool>)> {
334        if self.replay {
335            while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
336                self.operation_ix += 1;
337                if let (StoredOperation::Server(operation), applied) = stored_operation {
338                    return Some((operation.clone(), applied.clone()));
339                }
340            }
341            None
342        } else {
343            let operation = self.generate_server_operation(clients)?;
344            let applied = Arc::new(AtomicBool::new(false));
345            self.stored_operations
346                .push((StoredOperation::Server(operation.clone()), applied.clone()));
347            Some((operation, applied))
348        }
349    }
350
351    fn next_client_operation(
352        &mut self,
353        client: &TestClient,
354        current_batch_id: usize,
355        cx: &TestAppContext,
356    ) -> Option<(T::Operation, Arc<AtomicBool>)> {
357        let current_user_id = client.current_user_id(cx);
358        let user_ix = self
359            .users
360            .iter()
361            .position(|user| user.user_id == current_user_id)
362            .unwrap();
363        let user_plan = &mut self.users[user_ix];
364
365        if self.replay {
366            while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
367                user_plan.operation_ix += 1;
368                if let (
369                    StoredOperation::Client {
370                        user_id, operation, ..
371                    },
372                    applied,
373                ) = stored_operation
374                    && user_id == &current_user_id
375                {
376                    return Some((operation.clone(), applied.clone()));
377                }
378            }
379            None
380        } else {
381            if self.operation_ix == self.max_operations {
382                return None;
383            }
384            self.operation_ix += 1;
385            let operation = T::generate_operation(
386                client,
387                &mut self.rng,
388                self.users
389                    .iter_mut()
390                    .find(|user| user.user_id == current_user_id)
391                    .unwrap(),
392                cx,
393            );
394            let applied = Arc::new(AtomicBool::new(false));
395            self.stored_operations.push((
396                StoredOperation::Client {
397                    user_id: current_user_id,
398                    batch_id: current_batch_id,
399                    operation: operation.clone(),
400                },
401                applied.clone(),
402            ));
403            Some((operation, applied))
404        }
405    }
406
    /// Randomly picks the next server operation, or returns `None` once the
    /// operation budget is used up.
    ///
    /// A number in `0..100` is drawn and matched against the arms below. When the
    /// drawn arm's precondition does not hold and no later arm applies, the draw
    /// is repeated.
    fn generate_server_operation(
        &mut self,
        clients: &[(Rc<TestClient>, TestAppContext)],
    ) -> Option<ServerOperation> {
        if self.operation_ix == self.max_operations {
            return None;
        }

        Some(loop {
            break match self.rng.random_range(0..100) {
                // Connect a randomly chosen user that is not online yet.
                // NOTE(review): the `unwrap` presumes an offline user exists
                // whenever there are fewer clients than users; confirm.
                0..=29 if clients.len() < self.users.len() => {
                    let user = self
                        .users
                        .iter()
                        .filter(|u| !u.online)
                        .choose(&mut self.rng)
                        .unwrap();
                    self.operation_ix += 1;
                    ServerOperation::AddConnection {
                        user_id: user.user_id,
                    }
                }
                // Fully disconnect a random client.
                30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
                    let (client, cx) = &clients[self.rng.random_range(0..clients.len())];
                    let user_id = client.current_user_id(cx);
                    self.operation_ix += 1;
                    ServerOperation::RemoveConnection { user_id }
                }
                // Temporarily disconnect a random client.
                35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
                    let (client, cx) = &clients[self.rng.random_range(0..clients.len())];
                    let user_id = client.current_user_id(cx);
                    self.operation_ix += 1;
                    ServerOperation::BounceConnection { user_id }
                }
                40..=44 if self.allow_server_restarts && clients.len() > 1 => {
                    self.operation_ix += 1;
                    ServerOperation::RestartServer
                }
                // Otherwise ask a batch of clients to perform operations. The
                // batch itself does not increment `operation_ix`; that happens
                // per client operation in `next_client_operation`. The batch
                // size is capped by the remaining budget.
                _ if !clients.is_empty() => {
                    let count = self
                        .rng
                        .random_range(1..10)
                        .min(self.max_operations - self.operation_ix);
                    let batch_id = util::post_inc(&mut self.next_batch_id);
                    // The same user may be picked more than once.
                    let mut user_ids = (0..count)
                        .map(|_| {
                            let ix = self.rng.random_range(0..clients.len());
                            let (client, cx) = &clients[ix];
                            client.current_user_id(cx)
                        })
                        .collect::<Vec<_>>();
                    user_ids.sort_unstable();
                    ServerOperation::MutateClients {
                        user_ids,
                        batch_id,
                        quiesce: self.rng.random_bool(0.7),
                    }
                }
                // Nothing applicable for this draw: draw again.
                _ => continue,
            };
        })
    }
469
    /// Applies one server operation and reports whether it took effect.
    ///
    /// Returns `false` when the operation does not apply to the current state:
    /// adding a user that is already online, removing a user that has no client,
    /// bouncing a user that has no connection, or mutating a batch in which none
    /// of the users has a client. Returns `true` otherwise.
    ///
    /// `clients`, `client_tasks` and `operation_channels` are kept index-aligned:
    /// entries are pushed together when a connection is added and removed at the
    /// same index when it is removed.
    async fn apply_server_operation(
        plan: Arc<Mutex<Self>>,
        deterministic: BackgroundExecutor,
        server: &mut TestServer,
        clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
        client_tasks: &mut Vec<Task<()>>,
        operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
        operation: ServerOperation,
        cx: &mut TestAppContext,
    ) -> bool {
        match operation {
            ServerOperation::AddConnection { user_id } => {
                let username;
                {
                    // Scope the lock so it is released before the awaits below.
                    let mut plan = plan.lock();
                    let user = plan.user(user_id);
                    if user.online {
                        return false;
                    }
                    user.online = true;
                    username = user.username.clone();
                };
                log::info!("adding new connection for {}", username);

                let mut client_cx = cx.new_app();

                let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
                let client = Rc::new(server.create_client(&mut client_cx, &username).await);
                operation_channels.push(operation_tx);
                clients.push((client.clone(), client_cx.clone()));

                // Run the client's simulation on its own foreground executor.
                let foreground_executor = client_cx.foreground_executor().clone();
                let simulate_client =
                    Self::simulate_client(plan.clone(), client, operation_rx, client_cx);
                client_tasks.push(foreground_executor.spawn(simulate_client));

                log::info!("added connection for {}", username);
            }

            ServerOperation::RemoveConnection {
                user_id: removed_user_id,
            } => {
                log::info!("simulating full disconnection of user {}", removed_user_id);
                let client_ix = clients
                    .iter()
                    .position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
                let Some(client_ix) = client_ix else {
                    return false;
                };
                let user_connection_ids = server
                    .connection_pool
                    .lock()
                    .user_connection_ids(removed_user_id)
                    .collect::<Vec<_>>();
                assert_eq!(user_connection_ids.len(), 1);
                let removed_peer_id = user_connection_ids[0].into();
                let (client, client_cx) = clients.remove(client_ix);
                let client_task = client_tasks.remove(client_ix);
                // Removing (and thereby dropping) the sender closes the
                // client's channel, so its simulation loop can end.
                operation_channels.remove(client_ix);
                // Connections are forbidden until the client task has finished.
                server.forbid_connections();
                server.disconnect_client(removed_peer_id);
                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
                log::info!("waiting for user {} to exit...", removed_user_id);
                client_task.await;
                server.allow_connections();

                for project in client.dev_server_projects().iter() {
                    project.read_with(&client_cx, |project, cx| {
                        assert!(
                            project.is_disconnected(cx),
                            "project {:?} should be read only",
                            project.remote_id()
                        )
                    });
                }

                // Every remaining client that has the removed user as an
                // accepted contact must see that user as offline and not busy.
                for (client, cx) in clients {
                    let contacts = server
                        .app_state
                        .db
                        .get_contacts(client.current_user_id(cx))
                        .await
                        .unwrap();
                    let pool = server.connection_pool.lock();
                    for contact in contacts {
                        if let db::Contact::Accepted { user_id, busy, .. } = contact
                            && user_id == removed_user_id
                        {
                            assert!(!pool.is_user_online(user_id));
                            assert!(!busy);
                        }
                    }
                }

                log::info!("{} removed", client.username);
                plan.lock().user(removed_user_id).online = false;
                // Tear down the removed client's app: close its windows, then
                // clear its globals and drop the client inside the app update.
                client_cx.update(|cx| {
                    for window in cx.windows() {
                        window
                            .update(cx, |_, window, _| window.remove_window())
                            .ok();
                    }
                });
                client_cx.update(|cx| {
                    cx.clear_globals();
                    drop(client);
                });
            }

            ServerOperation::BounceConnection { user_id } => {
                log::info!("simulating temporary disconnection of user {}", user_id);
                let user_connection_ids = server
                    .connection_pool
                    .lock()
                    .user_connection_ids(user_id)
                    .collect::<Vec<_>>();
                if user_connection_ids.is_empty() {
                    return false;
                }
                assert_eq!(user_connection_ids.len(), 1);
                let peer_id = user_connection_ids[0].into();
                server.disconnect_client(peer_id);
                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
            }

            ServerOperation::RestartServer => {
                log::info!("simulating server restart");
                server.reset().await;
                deterministic.advance_clock(RECEIVE_TIMEOUT);
                server.start().await.unwrap();
                deterministic.advance_clock(CLEANUP_TIMEOUT);
                // After the restart, no stale rooms may be left for this server.
                let environment = &server.app_state.config.zed_environment;
                let (stale_room_ids, _) = server
                    .app_state
                    .db
                    .stale_server_resource_ids(environment, server.id())
                    .await
                    .unwrap();
                assert_eq!(stale_room_ids, vec![]);
            }

            ServerOperation::MutateClients {
                user_ids,
                batch_id,
                quiesce,
            } => {
                // The batch counts as applied if at least one of its users
                // currently has a client.
                let mut applied = false;
                for user_id in user_ids {
                    let client_ix = clients
                        .iter()
                        .position(|(client, cx)| client.current_user_id(cx) == user_id);
                    let Some(client_ix) = client_ix else { continue };
                    applied = true;
                    // Tell the client's task to perform one operation for this batch.
                    if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
                        log::error!("error signaling user {user_id}: {err}");
                    }
                }

                if quiesce && applied {
                    deterministic.run_until_parked();
                    T::on_quiesce(server, clients).await;
                }

                return applied;
            }
        }
        true
    }
638
639    async fn simulate_client(
640        plan: Arc<Mutex<Self>>,
641        client: Rc<TestClient>,
642        mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
643        mut cx: TestAppContext,
644    ) {
645        T::on_client_added(&client, &mut cx).await;
646
647        while let Some(batch_id) = operation_rx.next().await {
648            let Some((operation, applied)) =
649                plan.lock().next_client_operation(&client, batch_id, &cx)
650            else {
651                break;
652            };
653            applied.store(true, SeqCst);
654            match T::apply_operation(&client, operation, &mut cx).await {
655                Ok(()) => {}
656                Err(TestError::Inapplicable) => {
657                    applied.store(false, SeqCst);
658                    log::info!("skipped operation");
659                }
660                Err(TestError::Other(error)) => {
661                    log::error!("{} error: {}", client.username, error);
662                }
663            }
664            cx.executor().simulate_random_delay().await;
665        }
666        log::info!("{}: done", client.username);
667    }
668
669    fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
670        self.users
671            .iter_mut()
672            .find(|user| user.user_id == user_id)
673            .unwrap()
674    }
675}
676
677impl UserTestPlan {
678    pub fn next_root_dir_name(&mut self) -> String {
679        let user_id = self.user_id;
680        let root_id = util::post_inc(&mut self.next_root_id);
681        format!("dir-{user_id}-{root_id}")
682    }
683}
684
685impl From<anyhow::Error> for TestError {
686    fn from(value: anyhow::Error) -> Self {
687        Self::Other(value)
688    }
689}
690
691fn path_env_var(name: &str) -> Option<PathBuf> {
692    let value = env::var(name).ok()?;
693    let mut path = PathBuf::from(value);
694    if path.is_relative() {
695        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
696        abs_path.pop();
697        abs_path.pop();
698        abs_path.push(path);
699        path = abs_path
700    }
701    Some(path)
702}