randomized_test_helpers.rs

  1use crate::{
  2    db::{self, NewUserParams, UserId},
  3    rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  4    tests::{TestClient, TestServer},
  5};
  6use async_trait::async_trait;
  7use futures::StreamExt;
  8use gpui::{BackgroundExecutor, Task, TestAppContext};
  9use parking_lot::Mutex;
 10use rand::prelude::*;
 11use rpc::RECEIVE_TIMEOUT;
 12use serde::{de::DeserializeOwned, Deserialize, Serialize};
 13use settings::SettingsStore;
 14use std::sync::OnceLock;
 15use std::{
 16    env,
 17    path::PathBuf,
 18    rc::Rc,
 19    sync::{
 20        atomic::{AtomicBool, Ordering::SeqCst},
 21        Arc,
 22    },
 23};
 24
 25fn plan_load_path() -> &'static Option<PathBuf> {
 26    static PLAN_LOAD_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
 27    PLAN_LOAD_PATH.get_or_init(|| path_env_var("LOAD_PLAN"))
 28}
 29
 30fn plan_save_path() -> &'static Option<PathBuf> {
 31    static PLAN_SAVE_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
 32    PLAN_SAVE_PATH.get_or_init(|| path_env_var("SAVE_PLAN"))
 33}
 34
 35fn max_peers() -> usize {
 36    static MAX_PEERS: OnceLock<usize> = OnceLock::new();
 37    *MAX_PEERS.get_or_init(|| {
 38        env::var("MAX_PEERS")
 39            .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 40            .unwrap_or(3)
 41    })
 42}
 43
 44fn max_operations() -> usize {
 45    static MAX_OPERATIONS: OnceLock<usize> = OnceLock::new();
 46    *MAX_OPERATIONS.get_or_init(|| {
 47        env::var("OPERATIONS")
 48            .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 49            .unwrap_or(10)
 50    })
 51}
 52
/// Raw bytes of the plan file named by `LOAD_PLAN`. Filled on first use in
/// `TestPlan::new` and reused afterwards, so the file is read at most once per
/// process.
static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
/// Serializer for the plan most recently created by `run_randomized_test`.
/// `save_randomized_test_plan` takes and invokes it; presumably this is meant
/// for a path where the test did not finish normally — confirm with callers.
static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
 55
 56struct TestPlan<T: RandomizedTest> {
 57    rng: StdRng,
 58    replay: bool,
 59    stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
 60    max_operations: usize,
 61    operation_ix: usize,
 62    users: Vec<UserTestPlan>,
 63    next_batch_id: usize,
 64    allow_server_restarts: bool,
 65    allow_client_reconnection: bool,
 66    allow_client_disconnection: bool,
 67}
 68
/// Per-user state kept by the test plan. A mutable reference to the acting
/// user's entry is passed to `RandomizedTest::generate_operation`.
pub struct UserTestPlan {
    /// Database id of the user created in `TestPlan::new`.
    pub user_id: UserId,
    /// `user-N` login name used to create the user and its client.
    pub username: String,
    /// Copy of the plan-wide toggle rolled in `TestPlan::new`.
    pub allow_client_reconnection: bool,
    /// Copy of the plan-wide toggle rolled in `TestPlan::new`.
    pub allow_client_disconnection: bool,
    /// Counter consumed by `next_root_dir_name`.
    next_root_id: usize,
    /// Replay cursor: index into the plan's stored operations from which the
    /// search for this user's next client operation resumes.
    operation_ix: usize,
    /// Whether the user currently has a connected client.
    online: bool,
}
 78
/// One entry of a saved test plan.
///
/// The enum is `#[serde(untagged)]`: on deserialization serde tries the
/// variants in declaration order, so `Server` is attempted before `Client`.
/// Do not reorder the variants.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StoredOperation<T> {
    /// An operation driven by the harness (connections, restarts, batches).
    Server(ServerOperation),
    /// An operation performed by `user_id`'s client in response to the
    /// `MutateClients` batch identified by `batch_id`.
    Client {
        user_id: UserId,
        batch_id: usize,
        operation: T,
    },
}
 89
 90#[derive(Clone, Debug, Serialize, Deserialize)]
 91enum ServerOperation {
 92    AddConnection {
 93        user_id: UserId,
 94    },
 95    RemoveConnection {
 96        user_id: UserId,
 97    },
 98    BounceConnection {
 99        user_id: UserId,
100    },
101    RestartServer,
102    MutateClients {
103        batch_id: usize,
104        #[serde(skip_serializing)]
105        #[serde(skip_deserializing)]
106        user_ids: Vec<UserId>,
107        quiesce: bool,
108    },
109}
110
/// Failure outcome of `RandomizedTest::apply_operation`.
pub enum TestError {
    /// The operation does not apply in the current state. The operation's
    /// applied flag is cleared, so it is left out of the saved plan.
    Inapplicable,
    /// Any other error. It is logged and the operation stays marked as applied.
    Other(anyhow::Error),
}
115
/// Hooks a concrete randomized test provides to `run_randomized_test`.
#[async_trait(?Send)]
pub trait RandomizedTest: 'static + Sized {
    /// A single client-side action. It must be serializable so that plans can
    /// be saved and replayed.
    type Operation: Send + Clone + Serialize + DeserializeOwned;

    /// Chooses the next operation for `client`. Only called when the plan is
    /// not replaying; the result is recorded in the plan. `plan` is the acting
    /// user's entry.
    fn generate_operation(
        client: &TestClient,
        rng: &mut StdRng,
        plan: &mut UserTestPlan,
        cx: &TestAppContext,
    ) -> Self::Operation;

    /// Performs `operation` on `client` (for both generated and replayed
    /// operations). Returning `Err(TestError::Inapplicable)` excludes the
    /// operation from the saved plan.
    async fn apply_operation(
        client: &TestClient,
        operation: Self::Operation,
        cx: &mut TestAppContext,
    ) -> Result<(), TestError>;

    /// Called once from `TestPlan::new`, after all users have been created in
    /// the database.
    async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);

    /// Called at the start of each client's simulation task. Defaults to a no-op.
    async fn on_client_added(_client: &Rc<TestClient>, _cx: &mut TestAppContext) {}

    /// Called after the executor has run until parked: at the end of the run,
    /// and after each applied `MutateClients` batch whose `quiesce` flag is set.
    async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
}
139
/// Runs one randomized test of type `T`.
///
/// Starts a test server and builds a [`TestPlan`]. It then repeatedly takes the
/// plan's next server operation and applies it until the plan is exhausted;
/// client operations run concurrently in the per-client tasks spawned by
/// `AddConnection`. Finally it waits for every client task, calls
/// `T::on_quiesce`, and clears each client's app globals (keeping the
/// `SettingsStore`). If `SAVE_PLAN` is set, the plan is written to that path.
pub async fn run_randomized_test<T: RandomizedTest>(
    cx: &mut TestAppContext,
    executor: BackgroundExecutor,
    rng: StdRng,
) {
    let mut server = TestServer::start(executor.clone()).await;
    let plan = TestPlan::<T>::new(&mut server, rng).await;

    // Register a serializer for this plan so `save_randomized_test_plan` can
    // write it out independently of this function.
    LAST_PLAN.lock().replace({
        let plan = plan.clone();
        Box::new(move || plan.lock().serialize())
    });

    // These three vectors stay index-aligned: entry `i` of each belongs to the
    // same client.
    let mut clients = Vec::new();
    let mut client_tasks = Vec::new();
    let mut operation_channels = Vec::new();
    loop {
        // `let ... else` (rather than `while let`) drops the plan lock at the end
        // of this statement. A `while let` scrutinee would keep the guard alive
        // for the whole loop body, and `apply_server_operation` locks the plan again.
        let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
            break;
        };
        // Mark as applied up front; cleared below if the operation was not applicable.
        applied.store(true, SeqCst);
        let did_apply = TestPlan::apply_server_operation(
            plan.clone(),
            executor.clone(),
            &mut server,
            &mut clients,
            &mut client_tasks,
            &mut operation_channels,
            next_operation,
            cx,
        )
        .await;
        if !did_apply {
            applied.store(false, SeqCst);
        }
    }

    // Dropping the senders closes every channel, which ends each client task's
    // receive loop.
    drop(operation_channels);
    executor.start_waiting();
    futures::future::join_all(client_tasks).await;
    executor.finish_waiting();

    executor.run_until_parked();
    T::on_quiesce(&mut server, &mut clients).await;

    for (client, cx) in clients {
        cx.update(|cx| {
            // Clear every global except the settings store, then drop the client
            // while still inside the app context.
            let store = cx.remove_global::<SettingsStore>();
            cx.clear_globals();
            cx.set_global(store);
            drop(client);
        });
    }
    executor.run_until_parked();

    if let Some(path) = plan_save_path() {
        eprintln!("saved test plan to path {:?}", path);
        std::fs::write(path, plan.lock().serialize()).unwrap();
    }
}
200
201pub fn save_randomized_test_plan() {
202    if let Some(serialize_plan) = LAST_PLAN.lock().take() {
203        if let Some(path) = plan_save_path() {
204            eprintln!("saved test plan to path {:?}", path);
205            std::fs::write(path, serialize_plan()).unwrap();
206        }
207    }
208}
209
210impl<T: RandomizedTest> TestPlan<T> {
211    pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
212        let allow_server_restarts = rng.gen_bool(0.7);
213        let allow_client_reconnection = rng.gen_bool(0.7);
214        let allow_client_disconnection = rng.gen_bool(0.1);
215
216        let mut users = Vec::new();
217        for ix in 0..max_peers() {
218            let username = format!("user-{}", ix + 1);
219            let user_id = server
220                .app_state
221                .db
222                .create_user(
223                    &format!("{username}@example.com"),
224                    false,
225                    NewUserParams {
226                        github_login: username.clone(),
227                        github_user_id: ix as i32,
228                    },
229                )
230                .await
231                .unwrap()
232                .user_id;
233            users.push(UserTestPlan {
234                user_id,
235                username,
236                online: false,
237                next_root_id: 0,
238                operation_ix: 0,
239                allow_client_disconnection,
240                allow_client_reconnection,
241            });
242        }
243
244        T::initialize(server, &users).await;
245
246        let plan = Arc::new(Mutex::new(Self {
247            replay: false,
248            allow_server_restarts,
249            allow_client_reconnection,
250            allow_client_disconnection,
251            stored_operations: Vec::new(),
252            operation_ix: 0,
253            next_batch_id: 0,
254            max_operations: max_operations(),
255            users,
256            rng,
257        }));
258
259        if let Some(path) = plan_load_path() {
260            let json = LOADED_PLAN_JSON
261                .lock()
262                .get_or_insert_with(|| {
263                    eprintln!("loaded test plan from path {:?}", path);
264                    std::fs::read(path).unwrap()
265                })
266                .clone();
267            plan.lock().deserialize(json);
268        }
269
270        plan
271    }
272
273    fn deserialize(&mut self, json: Vec<u8>) {
274        let stored_operations: Vec<StoredOperation<T::Operation>> =
275            serde_json::from_slice(&json).unwrap();
276        self.replay = true;
277        self.stored_operations = stored_operations
278            .iter()
279            .cloned()
280            .enumerate()
281            .map(|(i, mut operation)| {
282                let did_apply = Arc::new(AtomicBool::new(false));
283                if let StoredOperation::Server(ServerOperation::MutateClients {
284                    batch_id: current_batch_id,
285                    user_ids,
286                    ..
287                }) = &mut operation
288                {
289                    assert!(user_ids.is_empty());
290                    user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
291                        if let StoredOperation::Client {
292                            user_id, batch_id, ..
293                        } = operation
294                        {
295                            if batch_id == current_batch_id {
296                                return Some(user_id);
297                            }
298                        }
299                        None
300                    }));
301                    user_ids.sort_unstable();
302                }
303                (operation, did_apply)
304            })
305            .collect()
306    }
307
308    fn serialize(&mut self) -> Vec<u8> {
309        // Format each operation as one line
310        let mut json = Vec::new();
311        json.push(b'[');
312        for (operation, applied) in &self.stored_operations {
313            if !applied.load(SeqCst) {
314                continue;
315            }
316            if json.len() > 1 {
317                json.push(b',');
318            }
319            json.extend_from_slice(b"\n  ");
320            serde_json::to_writer(&mut json, operation).unwrap();
321        }
322        json.extend_from_slice(b"\n]\n");
323        json
324    }
325
326    fn next_server_operation(
327        &mut self,
328        clients: &[(Rc<TestClient>, TestAppContext)],
329    ) -> Option<(ServerOperation, Arc<AtomicBool>)> {
330        if self.replay {
331            while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
332                self.operation_ix += 1;
333                if let (StoredOperation::Server(operation), applied) = stored_operation {
334                    return Some((operation.clone(), applied.clone()));
335                }
336            }
337            None
338        } else {
339            let operation = self.generate_server_operation(clients)?;
340            let applied = Arc::new(AtomicBool::new(false));
341            self.stored_operations
342                .push((StoredOperation::Server(operation.clone()), applied.clone()));
343            Some((operation, applied))
344        }
345    }
346
347    fn next_client_operation(
348        &mut self,
349        client: &TestClient,
350        current_batch_id: usize,
351        cx: &TestAppContext,
352    ) -> Option<(T::Operation, Arc<AtomicBool>)> {
353        let current_user_id = client.current_user_id(cx);
354        let user_ix = self
355            .users
356            .iter()
357            .position(|user| user.user_id == current_user_id)
358            .unwrap();
359        let user_plan = &mut self.users[user_ix];
360
361        if self.replay {
362            while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
363                user_plan.operation_ix += 1;
364                if let (
365                    StoredOperation::Client {
366                        user_id, operation, ..
367                    },
368                    applied,
369                ) = stored_operation
370                {
371                    if user_id == &current_user_id {
372                        return Some((operation.clone(), applied.clone()));
373                    }
374                }
375            }
376            None
377        } else {
378            if self.operation_ix == self.max_operations {
379                return None;
380            }
381            self.operation_ix += 1;
382            let operation = T::generate_operation(
383                client,
384                &mut self.rng,
385                self.users
386                    .iter_mut()
387                    .find(|user| user.user_id == current_user_id)
388                    .unwrap(),
389                cx,
390            );
391            let applied = Arc::new(AtomicBool::new(false));
392            self.stored_operations.push((
393                StoredOperation::Client {
394                    user_id: current_user_id,
395                    batch_id: current_batch_id,
396                    operation: operation.clone(),
397                },
398                applied.clone(),
399            ));
400            Some((operation, applied))
401        }
402    }
403
404    fn generate_server_operation(
405        &mut self,
406        clients: &[(Rc<TestClient>, TestAppContext)],
407    ) -> Option<ServerOperation> {
408        if self.operation_ix == self.max_operations {
409            return None;
410        }
411
412        Some(loop {
413            break match self.rng.gen_range(0..100) {
414                0..=29 if clients.len() < self.users.len() => {
415                    let user = self
416                        .users
417                        .iter()
418                        .filter(|u| !u.online)
419                        .choose(&mut self.rng)
420                        .unwrap();
421                    self.operation_ix += 1;
422                    ServerOperation::AddConnection {
423                        user_id: user.user_id,
424                    }
425                }
426                30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
427                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
428                    let user_id = client.current_user_id(cx);
429                    self.operation_ix += 1;
430                    ServerOperation::RemoveConnection { user_id }
431                }
432                35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
433                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
434                    let user_id = client.current_user_id(cx);
435                    self.operation_ix += 1;
436                    ServerOperation::BounceConnection { user_id }
437                }
438                40..=44 if self.allow_server_restarts && clients.len() > 1 => {
439                    self.operation_ix += 1;
440                    ServerOperation::RestartServer
441                }
442                _ if !clients.is_empty() => {
443                    let count = self
444                        .rng
445                        .gen_range(1..10)
446                        .min(self.max_operations - self.operation_ix);
447                    let batch_id = util::post_inc(&mut self.next_batch_id);
448                    let mut user_ids = (0..count)
449                        .map(|_| {
450                            let ix = self.rng.gen_range(0..clients.len());
451                            let (client, cx) = &clients[ix];
452                            client.current_user_id(cx)
453                        })
454                        .collect::<Vec<_>>();
455                    user_ids.sort_unstable();
456                    ServerOperation::MutateClients {
457                        user_ids,
458                        batch_id,
459                        quiesce: self.rng.gen_bool(0.7),
460                    }
461                }
462                _ => continue,
463            };
464        })
465    }
466
467    #[allow(clippy::too_many_arguments)]
468    async fn apply_server_operation(
469        plan: Arc<Mutex<Self>>,
470        deterministic: BackgroundExecutor,
471        server: &mut TestServer,
472        clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
473        client_tasks: &mut Vec<Task<()>>,
474        operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
475        operation: ServerOperation,
476        cx: &mut TestAppContext,
477    ) -> bool {
478        match operation {
479            ServerOperation::AddConnection { user_id } => {
480                let username;
481                {
482                    let mut plan = plan.lock();
483                    let user = plan.user(user_id);
484                    if user.online {
485                        return false;
486                    }
487                    user.online = true;
488                    username = user.username.clone();
489                };
490                log::info!("adding new connection for {}", username);
491
492                let mut client_cx = cx.new_app();
493
494                let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
495                let client = Rc::new(server.create_client(&mut client_cx, &username).await);
496                operation_channels.push(operation_tx);
497                clients.push((client.clone(), client_cx.clone()));
498
499                let foreground_executor = client_cx.foreground_executor().clone();
500                let simulate_client =
501                    Self::simulate_client(plan.clone(), client, operation_rx, client_cx);
502                client_tasks.push(foreground_executor.spawn(simulate_client));
503
504                log::info!("added connection for {}", username);
505            }
506
507            ServerOperation::RemoveConnection {
508                user_id: removed_user_id,
509            } => {
510                log::info!("simulating full disconnection of user {}", removed_user_id);
511                let client_ix = clients
512                    .iter()
513                    .position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
514                let Some(client_ix) = client_ix else {
515                    return false;
516                };
517                let user_connection_ids = server
518                    .connection_pool
519                    .lock()
520                    .user_connection_ids(removed_user_id)
521                    .collect::<Vec<_>>();
522                assert_eq!(user_connection_ids.len(), 1);
523                let removed_peer_id = user_connection_ids[0].into();
524                let (client, client_cx) = clients.remove(client_ix);
525                let client_task = client_tasks.remove(client_ix);
526                operation_channels.remove(client_ix);
527                server.forbid_connections();
528                server.disconnect_client(removed_peer_id);
529                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
530                deterministic.start_waiting();
531                log::info!("waiting for user {} to exit...", removed_user_id);
532                client_task.await;
533                deterministic.finish_waiting();
534                server.allow_connections();
535
536                for project in client.dev_server_projects().iter() {
537                    project.read_with(&client_cx, |project, _| {
538                        assert!(
539                            project.is_disconnected(),
540                            "project {:?} should be read only",
541                            project.remote_id()
542                        )
543                    });
544                }
545
546                for (client, cx) in clients {
547                    let contacts = server
548                        .app_state
549                        .db
550                        .get_contacts(client.current_user_id(cx))
551                        .await
552                        .unwrap();
553                    let pool = server.connection_pool.lock();
554                    for contact in contacts {
555                        if let db::Contact::Accepted { user_id, busy, .. } = contact {
556                            if user_id == removed_user_id {
557                                assert!(!pool.is_user_online(user_id));
558                                assert!(!busy);
559                            }
560                        }
561                    }
562                }
563
564                log::info!("{} removed", client.username);
565                plan.lock().user(removed_user_id).online = false;
566                client_cx.update(|cx| {
567                    cx.clear_globals();
568                    drop(client);
569                });
570            }
571
572            ServerOperation::BounceConnection { user_id } => {
573                log::info!("simulating temporary disconnection of user {}", user_id);
574                let user_connection_ids = server
575                    .connection_pool
576                    .lock()
577                    .user_connection_ids(user_id)
578                    .collect::<Vec<_>>();
579                if user_connection_ids.is_empty() {
580                    return false;
581                }
582                assert_eq!(user_connection_ids.len(), 1);
583                let peer_id = user_connection_ids[0].into();
584                server.disconnect_client(peer_id);
585                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
586            }
587
588            ServerOperation::RestartServer => {
589                log::info!("simulating server restart");
590                server.reset().await;
591                deterministic.advance_clock(RECEIVE_TIMEOUT);
592                server.start().await.unwrap();
593                deterministic.advance_clock(CLEANUP_TIMEOUT);
594                let environment = &server.app_state.config.zed_environment;
595                let (stale_room_ids, _) = server
596                    .app_state
597                    .db
598                    .stale_server_resource_ids(environment, server.id())
599                    .await
600                    .unwrap();
601                assert_eq!(stale_room_ids, vec![]);
602            }
603
604            ServerOperation::MutateClients {
605                user_ids,
606                batch_id,
607                quiesce,
608            } => {
609                let mut applied = false;
610                for user_id in user_ids {
611                    let client_ix = clients
612                        .iter()
613                        .position(|(client, cx)| client.current_user_id(cx) == user_id);
614                    let Some(client_ix) = client_ix else { continue };
615                    applied = true;
616                    if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
617                        log::error!("error signaling user {user_id}: {err}");
618                    }
619                }
620
621                if quiesce && applied {
622                    deterministic.run_until_parked();
623                    T::on_quiesce(server, clients).await;
624                }
625
626                return applied;
627            }
628        }
629        true
630    }
631
632    async fn simulate_client(
633        plan: Arc<Mutex<Self>>,
634        client: Rc<TestClient>,
635        mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
636        mut cx: TestAppContext,
637    ) {
638        T::on_client_added(&client, &mut cx).await;
639
640        while let Some(batch_id) = operation_rx.next().await {
641            let Some((operation, applied)) =
642                plan.lock().next_client_operation(&client, batch_id, &cx)
643            else {
644                break;
645            };
646            applied.store(true, SeqCst);
647            match T::apply_operation(&client, operation, &mut cx).await {
648                Ok(()) => {}
649                Err(TestError::Inapplicable) => {
650                    applied.store(false, SeqCst);
651                    log::info!("skipped operation");
652                }
653                Err(TestError::Other(error)) => {
654                    log::error!("{} error: {}", client.username, error);
655                }
656            }
657            cx.executor().simulate_random_delay().await;
658        }
659        log::info!("{}: done", client.username);
660    }
661
662    fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
663        self.users
664            .iter_mut()
665            .find(|user| user.user_id == user_id)
666            .unwrap()
667    }
668}
669
670impl UserTestPlan {
671    pub fn next_root_dir_name(&mut self) -> String {
672        let user_id = self.user_id;
673        let root_id = util::post_inc(&mut self.next_root_id);
674        format!("dir-{user_id}-{root_id}")
675    }
676}
677
678impl From<anyhow::Error> for TestError {
679    fn from(value: anyhow::Error) -> Self {
680        Self::Other(value)
681    }
682}
683
684fn path_env_var(name: &str) -> Option<PathBuf> {
685    let value = env::var(name).ok()?;
686    let mut path = PathBuf::from(value);
687    if path.is_relative() {
688        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
689        abs_path.pop();
690        abs_path.pop();
691        abs_path.push(path);
692        path = abs_path
693    }
694    Some(path)
695}