@@ -9,7 +9,11 @@ use std::{
mem,
pin::Pin,
rc::Rc,
- sync::{mpsc::SyncSender, Arc},
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ mpsc::SyncSender,
+ Arc,
+ },
thread,
};
@@ -33,24 +37,25 @@ pub enum Background {
},
}
-#[derive(Default)]
-struct Runnables {
+struct DeterministicState {
+ rng: StdRng,
+ seed: u64,
scheduled: Vec<Runnable>,
spawned_from_foreground: Vec<Runnable>,
waker: Option<SyncSender<()>>,
}
-pub struct Deterministic {
- seed: u64,
- runnables: Arc<Mutex<Runnables>>,
-}
+pub struct Deterministic(Arc<Mutex<DeterministicState>>);
impl Deterministic {
fn new(seed: u64) -> Self {
- Self {
+ Self(Arc::new(Mutex::new(DeterministicState {
+ rng: StdRng::seed_from_u64(seed),
seed,
- runnables: Default::default(),
- }
+ scheduled: Default::default(),
+ spawned_from_foreground: Default::default(),
+ waker: None,
+ })))
}
pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
@@ -58,17 +63,16 @@ impl Deterministic {
T: 'static,
F: Future<Output = T> + 'static,
{
- let scheduled_once = Mutex::new(false);
- let runnables = self.runnables.clone();
+ let scheduled_once = AtomicBool::new(false);
+ let state = self.0.clone();
let (runnable, task) = async_task::spawn_local(future, move |runnable| {
- let mut runnables = runnables.lock();
- if *scheduled_once.lock() {
- runnables.scheduled.push(runnable);
+ let mut state = state.lock();
+ if scheduled_once.fetch_or(true, SeqCst) {
+ state.scheduled.push(runnable);
} else {
- runnables.spawned_from_foreground.push(runnable);
- *scheduled_once.lock() = true;
+ state.spawned_from_foreground.push(runnable);
}
- if let Some(waker) = runnables.waker.as_ref() {
+ if let Some(waker) = state.waker.as_ref() {
waker.send(()).ok();
}
});
@@ -81,11 +85,11 @@ impl Deterministic {
T: 'static + Send,
F: 'static + Send + Future<Output = T>,
{
- let runnables = self.runnables.clone();
+ let state = self.0.clone();
let (runnable, task) = async_task::spawn(future, move |runnable| {
- let mut runnables = runnables.lock();
- runnables.scheduled.push(runnable);
- if let Some(waker) = runnables.waker.as_ref() {
+ let mut state = state.lock();
+ state.scheduled.push(runnable);
+ if let Some(waker) = state.waker.as_ref() {
waker.send(()).ok();
}
});
@@ -99,8 +103,8 @@ impl Deterministic {
F: Future<Output = T> + 'static,
{
let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
- let runnables = self.runnables.clone();
- runnables.lock().waker = Some(wake_tx);
+ let state = self.0.clone();
+ state.lock().waker = Some(wake_tx);
let (output_tx, output_rx) = std::sync::mpsc::channel();
self.spawn_from_foreground(async move {
@@ -109,23 +113,22 @@ impl Deterministic {
})
.detach();
- let mut rng = StdRng::seed_from_u64(self.seed);
loop {
if let Ok(value) = output_rx.try_recv() {
- runnables.lock().waker = None;
+ state.lock().waker = None;
return value;
}
wake_rx.recv().unwrap();
let runnable = {
- let mut runnables = runnables.lock();
- let ix = rng.gen_range(
- 0..runnables.scheduled.len() + runnables.spawned_from_foreground.len(),
- );
- if ix < runnables.scheduled.len() {
- runnables.scheduled.remove(ix)
+ let state = &mut *state.lock();
+ let ix = state
+ .rng
+ .gen_range(0..state.scheduled.len() + state.spawned_from_foreground.len());
+ if ix < state.scheduled.len() {
+ state.scheduled.remove(ix)
} else {
- runnables.spawned_from_foreground.remove(0)
+ state.spawned_from_foreground.remove(0)
}
};
@@ -171,6 +174,17 @@ impl Foreground {
Self::Deterministic(executor) => executor.run(future),
}
}
+
+ pub fn reset(&self) {
+ match self {
+ Self::Platform { .. } => panic!("can't call this method on a platform executor"),
+ Self::Test(_) => panic!("can't call this method on a test executor"),
+ Self::Deterministic(executor) => {
+ let state = &mut *executor.0.lock();
+ state.rng = StdRng::seed_from_u64(state.seed);
+ }
+ }
+ }
}
impl Background {