randomized_test_helpers.rs

  1use crate::{
  2    db::{self, NewUserParams, UserId},
  3    rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  4    tests::{TestClient, TestServer},
  5};
  6use async_trait::async_trait;
  7use futures::StreamExt;
  8use gpui::{executor::Deterministic, Task, TestAppContext};
  9use parking_lot::Mutex;
 10use rand::prelude::*;
 11use rpc::RECEIVE_TIMEOUT;
 12use serde::{de::DeserializeOwned, Deserialize, Serialize};
 13use settings::SettingsStore;
 14use std::{
 15    env,
 16    path::PathBuf,
 17    rc::Rc,
 18    sync::{
 19        atomic::{AtomicBool, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23
 24lazy_static::lazy_static! {
 25    static ref PLAN_LOAD_PATH: Option<PathBuf> = path_env_var("LOAD_PLAN");
 26    static ref PLAN_SAVE_PATH: Option<PathBuf> = path_env_var("SAVE_PLAN");
 27    static ref MAX_PEERS: usize = env::var("MAX_PEERS")
 28        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 29        .unwrap_or(3);
 30    static ref MAX_OPERATIONS: usize = env::var("OPERATIONS")
 31        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 32        .unwrap_or(10);
 33
 34}
 35
 36static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
 37static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
 38
 39struct TestPlan<T: RandomizedTest> {
 40    rng: StdRng,
 41    replay: bool,
 42    stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
 43    max_operations: usize,
 44    operation_ix: usize,
 45    users: Vec<UserTestPlan>,
 46    next_batch_id: usize,
 47    allow_server_restarts: bool,
 48    allow_client_reconnection: bool,
 49    allow_client_disconnection: bool,
 50}
 51
 52pub struct UserTestPlan {
 53    pub user_id: UserId,
 54    pub username: String,
 55    pub allow_client_reconnection: bool,
 56    pub allow_client_disconnection: bool,
 57    next_root_id: usize,
 58    operation_ix: usize,
 59    online: bool,
 60}
 61
 62#[derive(Clone, Debug, Serialize, Deserialize)]
 63#[serde(untagged)]
 64enum StoredOperation<T> {
 65    Server(ServerOperation),
 66    Client {
 67        user_id: UserId,
 68        batch_id: usize,
 69        operation: T,
 70    },
 71}
 72
 73#[derive(Clone, Debug, Serialize, Deserialize)]
 74enum ServerOperation {
 75    AddConnection {
 76        user_id: UserId,
 77    },
 78    RemoveConnection {
 79        user_id: UserId,
 80    },
 81    BounceConnection {
 82        user_id: UserId,
 83    },
 84    RestartServer,
 85    MutateClients {
 86        batch_id: usize,
 87        #[serde(skip_serializing)]
 88        #[serde(skip_deserializing)]
 89        user_ids: Vec<UserId>,
 90        quiesce: bool,
 91    },
 92}
 93
 94pub enum TestError {
 95    Inapplicable,
 96    Other(anyhow::Error),
 97}
 98
 99#[async_trait(?Send)]
100pub trait RandomizedTest: 'static + Sized {
101    type Operation: Send + Clone + Serialize + DeserializeOwned;
102
103    fn generate_operation(
104        client: &TestClient,
105        rng: &mut StdRng,
106        plan: &mut UserTestPlan,
107        cx: &TestAppContext,
108    ) -> Self::Operation;
109
110    async fn apply_operation(
111        client: &TestClient,
112        operation: Self::Operation,
113        cx: &mut TestAppContext,
114    ) -> Result<(), TestError>;
115
116    async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);
117
118    async fn on_client_added(client: &Rc<TestClient>, cx: &mut TestAppContext);
119
120    async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
121}
122
123pub async fn run_randomized_test<T: RandomizedTest>(
124    cx: &mut TestAppContext,
125    deterministic: Arc<Deterministic>,
126    rng: StdRng,
127) {
128    deterministic.forbid_parking();
129    let mut server = TestServer::start(&deterministic).await;
130    let plan = TestPlan::<T>::new(&mut server, rng).await;
131
132    LAST_PLAN.lock().replace({
133        let plan = plan.clone();
134        Box::new(move || plan.lock().serialize())
135    });
136
137    let mut clients = Vec::new();
138    let mut client_tasks = Vec::new();
139    let mut operation_channels = Vec::new();
140    loop {
141        let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
142            break;
143        };
144        applied.store(true, SeqCst);
145        let did_apply = TestPlan::apply_server_operation(
146            plan.clone(),
147            deterministic.clone(),
148            &mut server,
149            &mut clients,
150            &mut client_tasks,
151            &mut operation_channels,
152            next_operation,
153            cx,
154        )
155        .await;
156        if !did_apply {
157            applied.store(false, SeqCst);
158        }
159    }
160
161    drop(operation_channels);
162    deterministic.start_waiting();
163    futures::future::join_all(client_tasks).await;
164    deterministic.finish_waiting();
165
166    deterministic.run_until_parked();
167    T::on_quiesce(&mut server, &mut clients).await;
168
169    for (client, mut cx) in clients {
170        cx.update(|cx| {
171            let store = cx.remove_global::<SettingsStore>();
172            cx.clear_globals();
173            cx.set_global(store);
174            drop(client);
175        });
176    }
177    deterministic.run_until_parked();
178
179    if let Some(path) = &*PLAN_SAVE_PATH {
180        eprintln!("saved test plan to path {:?}", path);
181        std::fs::write(path, plan.lock().serialize()).unwrap();
182    }
183}
184
185pub fn save_randomized_test_plan() {
186    if let Some(serialize_plan) = LAST_PLAN.lock().take() {
187        if let Some(path) = &*PLAN_SAVE_PATH {
188            eprintln!("saved test plan to path {:?}", path);
189            std::fs::write(path, serialize_plan()).unwrap();
190        }
191    }
192}
193
194impl<T: RandomizedTest> TestPlan<T> {
195    pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
196        let allow_server_restarts = rng.gen_bool(0.7);
197        let allow_client_reconnection = rng.gen_bool(0.7);
198        let allow_client_disconnection = rng.gen_bool(0.1);
199
200        let mut users = Vec::new();
201        for ix in 0..*MAX_PEERS {
202            let username = format!("user-{}", ix + 1);
203            let user_id = server
204                .app_state
205                .db
206                .create_user(
207                    &format!("{username}@example.com"),
208                    false,
209                    NewUserParams {
210                        github_login: username.clone(),
211                        github_user_id: (ix + 1) as i32,
212                        invite_count: 0,
213                    },
214                )
215                .await
216                .unwrap()
217                .user_id;
218            users.push(UserTestPlan {
219                user_id,
220                username,
221                online: false,
222                next_root_id: 0,
223                operation_ix: 0,
224                allow_client_disconnection,
225                allow_client_reconnection,
226            });
227        }
228
229        T::initialize(server, &users).await;
230
231        let plan = Arc::new(Mutex::new(Self {
232            replay: false,
233            allow_server_restarts,
234            allow_client_reconnection,
235            allow_client_disconnection,
236            stored_operations: Vec::new(),
237            operation_ix: 0,
238            next_batch_id: 0,
239            max_operations: *MAX_OPERATIONS,
240            users,
241            rng,
242        }));
243
244        if let Some(path) = &*PLAN_LOAD_PATH {
245            let json = LOADED_PLAN_JSON
246                .lock()
247                .get_or_insert_with(|| {
248                    eprintln!("loaded test plan from path {:?}", path);
249                    std::fs::read(path).unwrap()
250                })
251                .clone();
252            plan.lock().deserialize(json);
253        }
254
255        plan
256    }
257
258    fn deserialize(&mut self, json: Vec<u8>) {
259        let stored_operations: Vec<StoredOperation<T::Operation>> =
260            serde_json::from_slice(&json).unwrap();
261        self.replay = true;
262        self.stored_operations = stored_operations
263            .iter()
264            .cloned()
265            .enumerate()
266            .map(|(i, mut operation)| {
267                let did_apply = Arc::new(AtomicBool::new(false));
268                if let StoredOperation::Server(ServerOperation::MutateClients {
269                    batch_id: current_batch_id,
270                    user_ids,
271                    ..
272                }) = &mut operation
273                {
274                    assert!(user_ids.is_empty());
275                    user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
276                        if let StoredOperation::Client {
277                            user_id, batch_id, ..
278                        } = operation
279                        {
280                            if batch_id == current_batch_id {
281                                return Some(user_id);
282                            }
283                        }
284                        None
285                    }));
286                    user_ids.sort_unstable();
287                }
288                (operation, did_apply)
289            })
290            .collect()
291    }
292
293    fn serialize(&mut self) -> Vec<u8> {
294        // Format each operation as one line
295        let mut json = Vec::new();
296        json.push(b'[');
297        for (operation, applied) in &self.stored_operations {
298            if !applied.load(SeqCst) {
299                continue;
300            }
301            if json.len() > 1 {
302                json.push(b',');
303            }
304            json.extend_from_slice(b"\n  ");
305            serde_json::to_writer(&mut json, operation).unwrap();
306        }
307        json.extend_from_slice(b"\n]\n");
308        json
309    }
310
311    fn next_server_operation(
312        &mut self,
313        clients: &[(Rc<TestClient>, TestAppContext)],
314    ) -> Option<(ServerOperation, Arc<AtomicBool>)> {
315        if self.replay {
316            while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
317                self.operation_ix += 1;
318                if let (StoredOperation::Server(operation), applied) = stored_operation {
319                    return Some((operation.clone(), applied.clone()));
320                }
321            }
322            None
323        } else {
324            let operation = self.generate_server_operation(clients)?;
325            let applied = Arc::new(AtomicBool::new(false));
326            self.stored_operations
327                .push((StoredOperation::Server(operation.clone()), applied.clone()));
328            Some((operation, applied))
329        }
330    }
331
332    fn next_client_operation(
333        &mut self,
334        client: &TestClient,
335        current_batch_id: usize,
336        cx: &TestAppContext,
337    ) -> Option<(T::Operation, Arc<AtomicBool>)> {
338        let current_user_id = client.current_user_id(cx);
339        let user_ix = self
340            .users
341            .iter()
342            .position(|user| user.user_id == current_user_id)
343            .unwrap();
344        let user_plan = &mut self.users[user_ix];
345
346        if self.replay {
347            while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
348                user_plan.operation_ix += 1;
349                if let (
350                    StoredOperation::Client {
351                        user_id, operation, ..
352                    },
353                    applied,
354                ) = stored_operation
355                {
356                    if user_id == &current_user_id {
357                        return Some((operation.clone(), applied.clone()));
358                    }
359                }
360            }
361            None
362        } else {
363            if self.operation_ix == self.max_operations {
364                return None;
365            }
366            self.operation_ix += 1;
367            let operation = T::generate_operation(
368                client,
369                &mut self.rng,
370                self.users
371                    .iter_mut()
372                    .find(|user| user.user_id == current_user_id)
373                    .unwrap(),
374                cx,
375            );
376            let applied = Arc::new(AtomicBool::new(false));
377            self.stored_operations.push((
378                StoredOperation::Client {
379                    user_id: current_user_id,
380                    batch_id: current_batch_id,
381                    operation: operation.clone(),
382                },
383                applied.clone(),
384            ));
385            Some((operation, applied))
386        }
387    }
388
389    fn generate_server_operation(
390        &mut self,
391        clients: &[(Rc<TestClient>, TestAppContext)],
392    ) -> Option<ServerOperation> {
393        if self.operation_ix == self.max_operations {
394            return None;
395        }
396
397        Some(loop {
398            break match self.rng.gen_range(0..100) {
399                0..=29 if clients.len() < self.users.len() => {
400                    let user = self
401                        .users
402                        .iter()
403                        .filter(|u| !u.online)
404                        .choose(&mut self.rng)
405                        .unwrap();
406                    self.operation_ix += 1;
407                    ServerOperation::AddConnection {
408                        user_id: user.user_id,
409                    }
410                }
411                30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
412                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
413                    let user_id = client.current_user_id(cx);
414                    self.operation_ix += 1;
415                    ServerOperation::RemoveConnection { user_id }
416                }
417                35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
418                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
419                    let user_id = client.current_user_id(cx);
420                    self.operation_ix += 1;
421                    ServerOperation::BounceConnection { user_id }
422                }
423                40..=44 if self.allow_server_restarts && clients.len() > 1 => {
424                    self.operation_ix += 1;
425                    ServerOperation::RestartServer
426                }
427                _ if !clients.is_empty() => {
428                    let count = self
429                        .rng
430                        .gen_range(1..10)
431                        .min(self.max_operations - self.operation_ix);
432                    let batch_id = util::post_inc(&mut self.next_batch_id);
433                    let mut user_ids = (0..count)
434                        .map(|_| {
435                            let ix = self.rng.gen_range(0..clients.len());
436                            let (client, cx) = &clients[ix];
437                            client.current_user_id(cx)
438                        })
439                        .collect::<Vec<_>>();
440                    user_ids.sort_unstable();
441                    ServerOperation::MutateClients {
442                        user_ids,
443                        batch_id,
444                        quiesce: self.rng.gen_bool(0.7),
445                    }
446                }
447                _ => continue,
448            };
449        })
450    }
451
452    async fn apply_server_operation(
453        plan: Arc<Mutex<Self>>,
454        deterministic: Arc<Deterministic>,
455        server: &mut TestServer,
456        clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
457        client_tasks: &mut Vec<Task<()>>,
458        operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
459        operation: ServerOperation,
460        cx: &mut TestAppContext,
461    ) -> bool {
462        match operation {
463            ServerOperation::AddConnection { user_id } => {
464                let username;
465                {
466                    let mut plan = plan.lock();
467                    let user = plan.user(user_id);
468                    if user.online {
469                        return false;
470                    }
471                    user.online = true;
472                    username = user.username.clone();
473                };
474                log::info!("adding new connection for {}", username);
475                let next_entity_id = (user_id.0 * 10_000) as usize;
476                let mut client_cx = TestAppContext::new(
477                    cx.foreground_platform(),
478                    cx.platform(),
479                    deterministic.build_foreground(user_id.0 as usize),
480                    deterministic.build_background(),
481                    cx.font_cache(),
482                    cx.leak_detector(),
483                    next_entity_id,
484                    cx.function_name.clone(),
485                );
486
487                let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
488                let client = Rc::new(server.create_client(&mut client_cx, &username).await);
489                operation_channels.push(operation_tx);
490                clients.push((client.clone(), client_cx.clone()));
491                client_tasks.push(client_cx.foreground().spawn(Self::simulate_client(
492                    plan.clone(),
493                    client,
494                    operation_rx,
495                    client_cx,
496                )));
497
498                log::info!("added connection for {}", username);
499            }
500
501            ServerOperation::RemoveConnection {
502                user_id: removed_user_id,
503            } => {
504                log::info!("simulating full disconnection of user {}", removed_user_id);
505                let client_ix = clients
506                    .iter()
507                    .position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
508                let Some(client_ix) = client_ix else {
509                    return false;
510                };
511                let user_connection_ids = server
512                    .connection_pool
513                    .lock()
514                    .user_connection_ids(removed_user_id)
515                    .collect::<Vec<_>>();
516                assert_eq!(user_connection_ids.len(), 1);
517                let removed_peer_id = user_connection_ids[0].into();
518                let (client, mut client_cx) = clients.remove(client_ix);
519                let client_task = client_tasks.remove(client_ix);
520                operation_channels.remove(client_ix);
521                server.forbid_connections();
522                server.disconnect_client(removed_peer_id);
523                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
524                deterministic.start_waiting();
525                log::info!("waiting for user {} to exit...", removed_user_id);
526                client_task.await;
527                deterministic.finish_waiting();
528                server.allow_connections();
529
530                for project in client.remote_projects().iter() {
531                    project.read_with(&client_cx, |project, _| {
532                        assert!(
533                            project.is_read_only(),
534                            "project {:?} should be read only",
535                            project.remote_id()
536                        )
537                    });
538                }
539
540                for (client, cx) in clients {
541                    let contacts = server
542                        .app_state
543                        .db
544                        .get_contacts(client.current_user_id(cx))
545                        .await
546                        .unwrap();
547                    let pool = server.connection_pool.lock();
548                    for contact in contacts {
549                        if let db::Contact::Accepted { user_id, busy, .. } = contact {
550                            if user_id == removed_user_id {
551                                assert!(!pool.is_user_online(user_id));
552                                assert!(!busy);
553                            }
554                        }
555                    }
556                }
557
558                log::info!("{} removed", client.username);
559                plan.lock().user(removed_user_id).online = false;
560                client_cx.update(|cx| {
561                    cx.clear_globals();
562                    drop(client);
563                });
564            }
565
566            ServerOperation::BounceConnection { user_id } => {
567                log::info!("simulating temporary disconnection of user {}", user_id);
568                let user_connection_ids = server
569                    .connection_pool
570                    .lock()
571                    .user_connection_ids(user_id)
572                    .collect::<Vec<_>>();
573                if user_connection_ids.is_empty() {
574                    return false;
575                }
576                assert_eq!(user_connection_ids.len(), 1);
577                let peer_id = user_connection_ids[0].into();
578                server.disconnect_client(peer_id);
579                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
580            }
581
582            ServerOperation::RestartServer => {
583                log::info!("simulating server restart");
584                server.reset().await;
585                deterministic.advance_clock(RECEIVE_TIMEOUT);
586                server.start().await.unwrap();
587                deterministic.advance_clock(CLEANUP_TIMEOUT);
588                let environment = &server.app_state.config.zed_environment;
589                let (stale_room_ids, _) = server
590                    .app_state
591                    .db
592                    .stale_server_resource_ids(environment, server.id())
593                    .await
594                    .unwrap();
595                assert_eq!(stale_room_ids, vec![]);
596            }
597
598            ServerOperation::MutateClients {
599                user_ids,
600                batch_id,
601                quiesce,
602            } => {
603                let mut applied = false;
604                for user_id in user_ids {
605                    let client_ix = clients
606                        .iter()
607                        .position(|(client, cx)| client.current_user_id(cx) == user_id);
608                    let Some(client_ix) = client_ix else { continue };
609                    applied = true;
610                    if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
611                        log::error!("error signaling user {user_id}: {err}");
612                    }
613                }
614
615                if quiesce && applied {
616                    deterministic.run_until_parked();
617                    T::on_quiesce(server, clients).await;
618                }
619
620                return applied;
621            }
622        }
623        true
624    }
625
626    async fn simulate_client(
627        plan: Arc<Mutex<Self>>,
628        client: Rc<TestClient>,
629        mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
630        mut cx: TestAppContext,
631    ) {
632        T::on_client_added(&client, &mut cx).await;
633
634        while let Some(batch_id) = operation_rx.next().await {
635            let Some((operation, applied)) =
636                plan.lock().next_client_operation(&client, batch_id, &cx)
637            else {
638                break;
639            };
640            applied.store(true, SeqCst);
641            match T::apply_operation(&client, operation, &mut cx).await {
642                Ok(()) => {}
643                Err(TestError::Inapplicable) => {
644                    applied.store(false, SeqCst);
645                    log::info!("skipped operation");
646                }
647                Err(TestError::Other(error)) => {
648                    log::error!("{} error: {}", client.username, error);
649                }
650            }
651            cx.background().simulate_random_delay().await;
652        }
653        log::info!("{}: done", client.username);
654    }
655
656    fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
657        self.users
658            .iter_mut()
659            .find(|user| user.user_id == user_id)
660            .unwrap()
661    }
662}
663
664impl UserTestPlan {
665    pub fn next_root_dir_name(&mut self) -> String {
666        let user_id = self.user_id;
667        let root_id = util::post_inc(&mut self.next_root_id);
668        format!("dir-{user_id}-{root_id}")
669    }
670}
671
672impl From<anyhow::Error> for TestError {
673    fn from(value: anyhow::Error) -> Self {
674        Self::Other(value)
675    }
676}
677
678fn path_env_var(name: &str) -> Option<PathBuf> {
679    let value = env::var(name).ok()?;
680    let mut path = PathBuf::from(value);
681    if path.is_relative() {
682        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
683        abs_path.pop();
684        abs_path.pop();
685        abs_path.push(path);
686        path = abs_path
687    }
688    Some(path)
689}