randomized_test_helpers.rs

  1use crate::{
  2    db::{self, NewUserParams, UserId},
  3    rpc::{CLEANUP_TIMEOUT, RECONNECT_TIMEOUT},
  4    tests::{TestClient, TestServer},
  5};
  6use async_trait::async_trait;
  7use futures::StreamExt;
  8use gpui::{BackgroundExecutor, Task, TestAppContext};
  9use parking_lot::Mutex;
 10use rand::prelude::*;
 11use rpc::RECEIVE_TIMEOUT;
 12use serde::{de::DeserializeOwned, Deserialize, Serialize};
 13use settings::SettingsStore;
 14use std::{
 15    env,
 16    path::PathBuf,
 17    rc::Rc,
 18    sync::{
 19        atomic::{AtomicBool, Ordering::SeqCst},
 20        Arc,
 21    },
 22};
 23
 24lazy_static::lazy_static! {
 25    static ref PLAN_LOAD_PATH: Option<PathBuf> = path_env_var("LOAD_PLAN");
 26    static ref PLAN_SAVE_PATH: Option<PathBuf> = path_env_var("SAVE_PLAN");
 27    static ref MAX_PEERS: usize = env::var("MAX_PEERS")
 28        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
 29        .unwrap_or(3);
 30    static ref MAX_OPERATIONS: usize = env::var("OPERATIONS")
 31        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
 32        .unwrap_or(10);
 33
 34}
 35
/// Bytes of the plan file named by `LOAD_PLAN`. Filled on first use by
/// `TestPlan::new` and cloned from here on later calls, so the file is read
/// at most once.
static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
/// Closure that serializes the plan of the most recently started test run.
/// Installed by `run_randomized_test`, taken and invoked by
/// `save_randomized_test_plan`.
static LAST_PLAN: Mutex<Option<Box<dyn Send + FnOnce() -> Vec<u8>>>> = Mutex::new(None);
 38
/// Shared state of one randomized test run: the random source, the recorded
/// (or replayed) operations, and the per-user bookkeeping.
struct TestPlan<T: RandomizedTest> {
    /// Random source for every generated operation.
    rng: StdRng,
    /// `true` once a stored plan has been loaded with `deserialize`; operations
    /// are then read from `stored_operations` instead of being generated.
    replay: bool,
    /// Every operation of the plan, paired with a flag recording whether it was
    /// actually applied. Only flagged operations are written by `serialize`.
    stored_operations: Vec<(StoredOperation<T::Operation>, Arc<AtomicBool>)>,
    /// Generation stops once `operation_ix` reaches this value.
    max_operations: usize,
    /// While generating: number of operations counted against `max_operations`.
    /// While replaying: cursor into `stored_operations` for server operations.
    operation_ix: usize,
    /// One entry per user created in `new`.
    users: Vec<UserTestPlan>,
    /// Id handed to the next generated `MutateClients` batch.
    next_batch_id: usize,
    /// Gates generation of `RestartServer`.
    allow_server_restarts: bool,
    /// Gates generation of `BounceConnection`.
    allow_client_reconnection: bool,
    /// Gates generation of `RemoveConnection`.
    allow_client_disconnection: bool,
}
 51
/// Per-user state of a test plan, handed to `RandomizedTest` implementations.
pub struct UserTestPlan {
    /// Database id returned when the user was created.
    pub user_id: UserId,
    /// Name of the form `user-N`; also used as the GitHub login.
    pub username: String,
    /// Copy of the plan-wide reconnection flag.
    pub allow_client_reconnection: bool,
    /// Copy of the plan-wide disconnection flag.
    pub allow_client_disconnection: bool,
    /// Counter behind `next_root_dir_name`.
    next_root_id: usize,
    /// Replay cursor into the plan's stored operations for this user's client
    /// operations.
    operation_ix: usize,
    /// Set when a connection is added for the user, cleared when it is removed.
    online: bool,
}
 61
/// One entry of a recorded plan: either a server-level operation or an
/// operation performed by a single client. Serialized untagged, so the
/// variant is determined by the shape of the JSON value.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum StoredOperation<T> {
    /// An operation driven by the test harness itself.
    Server(ServerOperation),
    /// An operation generated for, and applied by, one client.
    Client {
        /// User whose client performs the operation.
        user_id: UserId,
        /// Id of the `MutateClients` batch that triggered the operation.
        batch_id: usize,
        /// The test-specific operation payload.
        operation: T,
    },
}
 72
 73#[derive(Clone, Debug, Serialize, Deserialize)]
 74enum ServerOperation {
 75    AddConnection {
 76        user_id: UserId,
 77    },
 78    RemoveConnection {
 79        user_id: UserId,
 80    },
 81    BounceConnection {
 82        user_id: UserId,
 83    },
 84    RestartServer,
 85    MutateClients {
 86        batch_id: usize,
 87        #[serde(skip_serializing)]
 88        #[serde(skip_deserializing)]
 89        user_ids: Vec<UserId>,
 90        quiesce: bool,
 91    },
 92}
 93
/// Failure modes of `RandomizedTest::apply_operation`.
pub enum TestError {
    /// The operation could not be applied in the current state; the harness
    /// marks it as not applied and logs that it was skipped.
    Inapplicable,
    /// Any other failure; the harness logs it.
    Other(anyhow::Error),
}
 98
/// Hooks a concrete randomized test supplies to the harness.
#[async_trait(?Send)]
pub trait RandomizedTest: 'static + Sized {
    /// Test-specific client operation. It is stored in the plan and written
    /// to / read from the plan file, hence the serde bounds.
    type Operation: Send + Clone + Serialize + DeserializeOwned;

    /// Produces the next operation for `client`. Called only when the plan is
    /// being generated, not when a stored plan is replayed.
    fn generate_operation(
        client: &TestClient,
        rng: &mut StdRng,
        plan: &mut UserTestPlan,
        cx: &TestAppContext,
    ) -> Self::Operation;

    /// Applies `operation` on behalf of `client`. Return
    /// `TestError::Inapplicable` to have the operation recorded as not applied.
    async fn apply_operation(
        client: &TestClient,
        operation: Self::Operation,
        cx: &mut TestAppContext,
    ) -> Result<(), TestError>;

    /// Called once while the plan is constructed, after all users have been
    /// created.
    async fn initialize(server: &mut TestServer, users: &[UserTestPlan]);

    /// Called at the start of each client's simulation task. Does nothing by
    /// default.
    async fn on_client_added(_client: &Rc<TestClient>, _cx: &mut TestAppContext) {}

    /// Called with all currently connected clients after the executor has run
    /// until parked: after a quiescing `MutateClients` batch and once at the
    /// end of the run.
    async fn on_quiesce(server: &mut TestServer, client: &mut [(Rc<TestClient>, TestAppContext)]);
}
122
123pub async fn run_randomized_test<T: RandomizedTest>(
124    cx: &mut TestAppContext,
125    executor: BackgroundExecutor,
126    rng: StdRng,
127) {
128    let mut server = TestServer::start(executor.clone()).await;
129    let plan = TestPlan::<T>::new(&mut server, rng).await;
130
131    LAST_PLAN.lock().replace({
132        let plan = plan.clone();
133        Box::new(move || plan.lock().serialize())
134    });
135
136    let mut clients = Vec::new();
137    let mut client_tasks = Vec::new();
138    let mut operation_channels = Vec::new();
139    loop {
140        let Some((next_operation, applied)) = plan.lock().next_server_operation(&clients) else {
141            break;
142        };
143        applied.store(true, SeqCst);
144        let did_apply = TestPlan::apply_server_operation(
145            plan.clone(),
146            executor.clone(),
147            &mut server,
148            &mut clients,
149            &mut client_tasks,
150            &mut operation_channels,
151            next_operation,
152            cx,
153        )
154        .await;
155        if !did_apply {
156            applied.store(false, SeqCst);
157        }
158    }
159
160    drop(operation_channels);
161    executor.start_waiting();
162    futures::future::join_all(client_tasks).await;
163    executor.finish_waiting();
164
165    executor.run_until_parked();
166    T::on_quiesce(&mut server, &mut clients).await;
167
168    for (client, cx) in clients {
169        cx.update(|cx| {
170            let store = cx.remove_global::<SettingsStore>();
171            cx.clear_globals();
172            cx.set_global(store);
173            drop(client);
174        });
175    }
176    executor.run_until_parked();
177
178    if let Some(path) = &*PLAN_SAVE_PATH {
179        eprintln!("saved test plan to path {:?}", path);
180        std::fs::write(path, plan.lock().serialize()).unwrap();
181    }
182}
183
184pub fn save_randomized_test_plan() {
185    if let Some(serialize_plan) = LAST_PLAN.lock().take() {
186        if let Some(path) = &*PLAN_SAVE_PATH {
187            eprintln!("saved test plan to path {:?}", path);
188            std::fs::write(path, serialize_plan()).unwrap();
189        }
190    }
191}
192
193impl<T: RandomizedTest> TestPlan<T> {
194    pub async fn new(server: &mut TestServer, mut rng: StdRng) -> Arc<Mutex<Self>> {
195        let allow_server_restarts = rng.gen_bool(0.7);
196        let allow_client_reconnection = rng.gen_bool(0.7);
197        let allow_client_disconnection = rng.gen_bool(0.1);
198
199        let mut users = Vec::new();
200        for ix in 0..*MAX_PEERS {
201            let username = format!("user-{}", ix + 1);
202            let user_id = server
203                .app_state
204                .db
205                .create_user(
206                    &format!("{username}@example.com"),
207                    false,
208                    NewUserParams {
209                        github_login: username.clone(),
210                        github_user_id: ix as i32,
211                    },
212                )
213                .await
214                .unwrap()
215                .user_id;
216            users.push(UserTestPlan {
217                user_id,
218                username,
219                online: false,
220                next_root_id: 0,
221                operation_ix: 0,
222                allow_client_disconnection,
223                allow_client_reconnection,
224            });
225        }
226
227        T::initialize(server, &users).await;
228
229        let plan = Arc::new(Mutex::new(Self {
230            replay: false,
231            allow_server_restarts,
232            allow_client_reconnection,
233            allow_client_disconnection,
234            stored_operations: Vec::new(),
235            operation_ix: 0,
236            next_batch_id: 0,
237            max_operations: *MAX_OPERATIONS,
238            users,
239            rng,
240        }));
241
242        if let Some(path) = &*PLAN_LOAD_PATH {
243            let json = LOADED_PLAN_JSON
244                .lock()
245                .get_or_insert_with(|| {
246                    eprintln!("loaded test plan from path {:?}", path);
247                    std::fs::read(path).unwrap()
248                })
249                .clone();
250            plan.lock().deserialize(json);
251        }
252
253        plan
254    }
255
256    fn deserialize(&mut self, json: Vec<u8>) {
257        let stored_operations: Vec<StoredOperation<T::Operation>> =
258            serde_json::from_slice(&json).unwrap();
259        self.replay = true;
260        self.stored_operations = stored_operations
261            .iter()
262            .cloned()
263            .enumerate()
264            .map(|(i, mut operation)| {
265                let did_apply = Arc::new(AtomicBool::new(false));
266                if let StoredOperation::Server(ServerOperation::MutateClients {
267                    batch_id: current_batch_id,
268                    user_ids,
269                    ..
270                }) = &mut operation
271                {
272                    assert!(user_ids.is_empty());
273                    user_ids.extend(stored_operations[i + 1..].iter().filter_map(|operation| {
274                        if let StoredOperation::Client {
275                            user_id, batch_id, ..
276                        } = operation
277                        {
278                            if batch_id == current_batch_id {
279                                return Some(user_id);
280                            }
281                        }
282                        None
283                    }));
284                    user_ids.sort_unstable();
285                }
286                (operation, did_apply)
287            })
288            .collect()
289    }
290
291    fn serialize(&mut self) -> Vec<u8> {
292        // Format each operation as one line
293        let mut json = Vec::new();
294        json.push(b'[');
295        for (operation, applied) in &self.stored_operations {
296            if !applied.load(SeqCst) {
297                continue;
298            }
299            if json.len() > 1 {
300                json.push(b',');
301            }
302            json.extend_from_slice(b"\n  ");
303            serde_json::to_writer(&mut json, operation).unwrap();
304        }
305        json.extend_from_slice(b"\n]\n");
306        json
307    }
308
309    fn next_server_operation(
310        &mut self,
311        clients: &[(Rc<TestClient>, TestAppContext)],
312    ) -> Option<(ServerOperation, Arc<AtomicBool>)> {
313        if self.replay {
314            while let Some(stored_operation) = self.stored_operations.get(self.operation_ix) {
315                self.operation_ix += 1;
316                if let (StoredOperation::Server(operation), applied) = stored_operation {
317                    return Some((operation.clone(), applied.clone()));
318                }
319            }
320            None
321        } else {
322            let operation = self.generate_server_operation(clients)?;
323            let applied = Arc::new(AtomicBool::new(false));
324            self.stored_operations
325                .push((StoredOperation::Server(operation.clone()), applied.clone()));
326            Some((operation, applied))
327        }
328    }
329
330    fn next_client_operation(
331        &mut self,
332        client: &TestClient,
333        current_batch_id: usize,
334        cx: &TestAppContext,
335    ) -> Option<(T::Operation, Arc<AtomicBool>)> {
336        let current_user_id = client.current_user_id(cx);
337        let user_ix = self
338            .users
339            .iter()
340            .position(|user| user.user_id == current_user_id)
341            .unwrap();
342        let user_plan = &mut self.users[user_ix];
343
344        if self.replay {
345            while let Some(stored_operation) = self.stored_operations.get(user_plan.operation_ix) {
346                user_plan.operation_ix += 1;
347                if let (
348                    StoredOperation::Client {
349                        user_id, operation, ..
350                    },
351                    applied,
352                ) = stored_operation
353                {
354                    if user_id == &current_user_id {
355                        return Some((operation.clone(), applied.clone()));
356                    }
357                }
358            }
359            None
360        } else {
361            if self.operation_ix == self.max_operations {
362                return None;
363            }
364            self.operation_ix += 1;
365            let operation = T::generate_operation(
366                client,
367                &mut self.rng,
368                self.users
369                    .iter_mut()
370                    .find(|user| user.user_id == current_user_id)
371                    .unwrap(),
372                cx,
373            );
374            let applied = Arc::new(AtomicBool::new(false));
375            self.stored_operations.push((
376                StoredOperation::Client {
377                    user_id: current_user_id,
378                    batch_id: current_batch_id,
379                    operation: operation.clone(),
380                },
381                applied.clone(),
382            ));
383            Some((operation, applied))
384        }
385    }
386
387    fn generate_server_operation(
388        &mut self,
389        clients: &[(Rc<TestClient>, TestAppContext)],
390    ) -> Option<ServerOperation> {
391        if self.operation_ix == self.max_operations {
392            return None;
393        }
394
395        Some(loop {
396            break match self.rng.gen_range(0..100) {
397                0..=29 if clients.len() < self.users.len() => {
398                    let user = self
399                        .users
400                        .iter()
401                        .filter(|u| !u.online)
402                        .choose(&mut self.rng)
403                        .unwrap();
404                    self.operation_ix += 1;
405                    ServerOperation::AddConnection {
406                        user_id: user.user_id,
407                    }
408                }
409                30..=34 if clients.len() > 1 && self.allow_client_disconnection => {
410                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
411                    let user_id = client.current_user_id(cx);
412                    self.operation_ix += 1;
413                    ServerOperation::RemoveConnection { user_id }
414                }
415                35..=39 if clients.len() > 1 && self.allow_client_reconnection => {
416                    let (client, cx) = &clients[self.rng.gen_range(0..clients.len())];
417                    let user_id = client.current_user_id(cx);
418                    self.operation_ix += 1;
419                    ServerOperation::BounceConnection { user_id }
420                }
421                40..=44 if self.allow_server_restarts && clients.len() > 1 => {
422                    self.operation_ix += 1;
423                    ServerOperation::RestartServer
424                }
425                _ if !clients.is_empty() => {
426                    let count = self
427                        .rng
428                        .gen_range(1..10)
429                        .min(self.max_operations - self.operation_ix);
430                    let batch_id = util::post_inc(&mut self.next_batch_id);
431                    let mut user_ids = (0..count)
432                        .map(|_| {
433                            let ix = self.rng.gen_range(0..clients.len());
434                            let (client, cx) = &clients[ix];
435                            client.current_user_id(cx)
436                        })
437                        .collect::<Vec<_>>();
438                    user_ids.sort_unstable();
439                    ServerOperation::MutateClients {
440                        user_ids,
441                        batch_id,
442                        quiesce: self.rng.gen_bool(0.7),
443                    }
444                }
445                _ => continue,
446            };
447        })
448    }
449
    /// Applies one server operation and reports whether it took effect.
    ///
    /// `clients`, `client_tasks` and `operation_channels` are kept in step:
    /// entries are pushed together when a connection is added and removed at
    /// the same index when one is removed.
    ///
    /// Returns `false` when the operation could not be applied: adding a
    /// connection for a user that is already online, removing or bouncing a
    /// user that has no client/connection, or a `MutateClients` batch none of
    /// whose users has a client. Returns `true` otherwise.
    async fn apply_server_operation(
        plan: Arc<Mutex<Self>>,
        deterministic: BackgroundExecutor,
        server: &mut TestServer,
        clients: &mut Vec<(Rc<TestClient>, TestAppContext)>,
        client_tasks: &mut Vec<Task<()>>,
        operation_channels: &mut Vec<futures::channel::mpsc::UnboundedSender<usize>>,
        operation: ServerOperation,
        cx: &mut TestAppContext,
    ) -> bool {
        match operation {
            ServerOperation::AddConnection { user_id } => {
                let username;
                // The plan lock lives only inside this block, so it is released
                // before the `.await` on client creation below.
                {
                    let mut plan = plan.lock();
                    let user = plan.user(user_id);
                    if user.online {
                        return false;
                    }
                    user.online = true;
                    username = user.username.clone();
                };
                log::info!("adding new connection for {}", username);

                let mut client_cx = cx.new_app();

                // The channel carries batch ids; each id received makes the
                // client's task perform one operation.
                let (operation_tx, operation_rx) = futures::channel::mpsc::unbounded();
                let client = Rc::new(server.create_client(&mut client_cx, &username).await);
                operation_channels.push(operation_tx);
                clients.push((client.clone(), client_cx.clone()));

                let foreground_executor = client_cx.foreground_executor().clone();
                let simulate_client =
                    Self::simulate_client(plan.clone(), client, operation_rx, client_cx);
                client_tasks.push(foreground_executor.spawn(simulate_client));

                log::info!("added connection for {}", username);
            }

            ServerOperation::RemoveConnection {
                user_id: removed_user_id,
            } => {
                log::info!("simulating full disconnection of user {}", removed_user_id);
                let client_ix = clients
                    .iter()
                    .position(|(client, cx)| client.current_user_id(cx) == removed_user_id);
                let Some(client_ix) = client_ix else {
                    return false;
                };
                let user_connection_ids = server
                    .connection_pool
                    .lock()
                    .user_connection_ids(removed_user_id)
                    .collect::<Vec<_>>();
                assert_eq!(user_connection_ids.len(), 1);
                let removed_peer_id = user_connection_ids[0].into();
                let (client, client_cx) = clients.remove(client_ix);
                let client_task = client_tasks.remove(client_ix);
                // Dropping the sender closes the channel the client's task
                // receives batch ids on.
                operation_channels.remove(client_ix);
                // Connections are forbidden for the duration of the
                // disconnect and re-allowed once the client's task has ended.
                server.forbid_connections();
                server.disconnect_client(removed_peer_id);
                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
                deterministic.start_waiting();
                log::info!("waiting for user {} to exit...", removed_user_id);
                client_task.await;
                deterministic.finish_waiting();
                server.allow_connections();

                // Every remote project of the removed client must now be
                // disconnected.
                for project in client.remote_projects().iter() {
                    project.read_with(&client_cx, |project, _| {
                        assert!(
                            project.is_disconnected(),
                            "project {:?} should be read only",
                            project.remote_id()
                        )
                    });
                }

                // From the point of view of every remaining client, the removed
                // user must be offline and not busy. Inside the loop `client`
                // and `cx` shadow the removed client's bindings.
                for (client, cx) in clients {
                    let contacts = server
                        .app_state
                        .db
                        .get_contacts(client.current_user_id(cx))
                        .await
                        .unwrap();
                    // Locked after the await above and released at the end of
                    // each iteration.
                    let pool = server.connection_pool.lock();
                    for contact in contacts {
                        if let db::Contact::Accepted { user_id, busy, .. } = contact {
                            if user_id == removed_user_id {
                                assert!(!pool.is_user_online(user_id));
                                assert!(!busy);
                            }
                        }
                    }
                }

                log::info!("{} removed", client.username);
                plan.lock().user(removed_user_id).online = false;
                client_cx.update(|cx| {
                    cx.clear_globals();
                    drop(client);
                });
            }

            ServerOperation::BounceConnection { user_id } => {
                log::info!("simulating temporary disconnection of user {}", user_id);
                let user_connection_ids = server
                    .connection_pool
                    .lock()
                    .user_connection_ids(user_id)
                    .collect::<Vec<_>>();
                if user_connection_ids.is_empty() {
                    return false;
                }
                assert_eq!(user_connection_ids.len(), 1);
                let peer_id = user_connection_ids[0].into();
                // Unlike `RemoveConnection`, the client stays in `clients` and
                // connections are not forbidden.
                server.disconnect_client(peer_id);
                deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
            }

            ServerOperation::RestartServer => {
                log::info!("simulating server restart");
                server.reset().await;
                deterministic.advance_clock(RECEIVE_TIMEOUT);
                server.start().await.unwrap();
                deterministic.advance_clock(CLEANUP_TIMEOUT);
                // After the cleanup timeout has elapsed, no stale rooms may be
                // left for this server.
                let environment = &server.app_state.config.zed_environment;
                let (stale_room_ids, _) = server
                    .app_state
                    .db
                    .stale_server_resource_ids(environment, server.id())
                    .await
                    .unwrap();
                assert_eq!(stale_room_ids, vec![]);
            }

            ServerOperation::MutateClients {
                user_ids,
                batch_id,
                quiesce,
            } => {
                // The batch counts as applied if at least one of its users
                // currently has a client.
                let mut applied = false;
                for user_id in user_ids {
                    let client_ix = clients
                        .iter()
                        .position(|(client, cx)| client.current_user_id(cx) == user_id);
                    let Some(client_ix) = client_ix else { continue };
                    applied = true;
                    if let Err(err) = operation_channels[client_ix].unbounded_send(batch_id) {
                        log::error!("error signaling user {user_id}: {err}");
                    }
                }

                if quiesce && applied {
                    deterministic.run_until_parked();
                    T::on_quiesce(server, clients).await;
                }

                return applied;
            }
        }
        true
    }
613
614    async fn simulate_client(
615        plan: Arc<Mutex<Self>>,
616        client: Rc<TestClient>,
617        mut operation_rx: futures::channel::mpsc::UnboundedReceiver<usize>,
618        mut cx: TestAppContext,
619    ) {
620        T::on_client_added(&client, &mut cx).await;
621
622        while let Some(batch_id) = operation_rx.next().await {
623            let Some((operation, applied)) =
624                plan.lock().next_client_operation(&client, batch_id, &cx)
625            else {
626                break;
627            };
628            applied.store(true, SeqCst);
629            match T::apply_operation(&client, operation, &mut cx).await {
630                Ok(()) => {}
631                Err(TestError::Inapplicable) => {
632                    applied.store(false, SeqCst);
633                    log::info!("skipped operation");
634                }
635                Err(TestError::Other(error)) => {
636                    log::error!("{} error: {}", client.username, error);
637                }
638            }
639            cx.executor().simulate_random_delay().await;
640        }
641        log::info!("{}: done", client.username);
642    }
643
644    fn user(&mut self, user_id: UserId) -> &mut UserTestPlan {
645        self.users
646            .iter_mut()
647            .find(|user| user.user_id == user_id)
648            .unwrap()
649    }
650}
651
652impl UserTestPlan {
653    pub fn next_root_dir_name(&mut self) -> String {
654        let user_id = self.user_id;
655        let root_id = util::post_inc(&mut self.next_root_id);
656        format!("dir-{user_id}-{root_id}")
657    }
658}
659
660impl From<anyhow::Error> for TestError {
661    fn from(value: anyhow::Error) -> Self {
662        Self::Other(value)
663    }
664}
665
666fn path_env_var(name: &str) -> Option<PathBuf> {
667    let value = env::var(name).ok()?;
668    let mut path = PathBuf::from(value);
669    if path.is_relative() {
670        let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
671        abs_path.pop();
672        abs_path.pop();
673        abs_path.push(path);
674        path = abs_path
675    }
676    Some(path)
677}