@@ -581,6 +581,7 @@ mod tests {
status.recv().await,
Some(Status::Connected { .. })
));
+ assert_eq!(server.auth_count(), 1);
server.forbid_connections();
server.disconnect().await;
@@ -589,6 +590,7 @@ mod tests {
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+ assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting
}
#[test]
@@ -21,7 +21,7 @@ use std::{
marker::PhantomData,
path::{Path, PathBuf},
sync::{
- atomic::{AtomicBool, Ordering::SeqCst},
+ atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
Arc,
},
};
@@ -209,6 +209,7 @@ pub struct FakeServer {
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
forbid_connections: AtomicBool,
+ auth_count: AtomicUsize,
}
impl FakeServer {
@@ -217,26 +218,31 @@ impl FakeServer {
client: &mut Arc<Client>,
cx: &TestAppContext,
) -> Arc<Self> {
- let result = Arc::new(Self {
+ let server = Arc::new(Self {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
forbid_connections: Default::default(),
+ auth_count: Default::default(),
});
Arc::get_mut(client)
.unwrap()
- .override_authenticate(move |cx| {
- cx.spawn(|_| async move {
- let access_token = "the-token".to_string();
- Ok(Credentials {
- user_id: client_user_id,
- access_token,
+ .override_authenticate({
+ let server = server.clone();
+ move |cx| {
+ server.auth_count.fetch_add(1, SeqCst);
+ cx.spawn(move |_| async move {
+ let access_token = "the-token".to_string();
+ Ok(Credentials {
+ user_id: client_user_id,
+ access_token,
+ })
})
- })
+ }
})
.override_establish_connection({
- let server = result.clone();
+ let server = server.clone();
move |credentials, cx| {
assert_eq!(credentials.user_id, client_user_id);
assert_eq!(credentials.access_token, "the-token");
@@ -251,7 +257,7 @@ impl FakeServer {
.authenticate_and_connect(&cx.to_async())
.await
.unwrap();
- result
+ server
}
pub async fn disconnect(&self) {
@@ -273,6 +279,10 @@ impl FakeServer {
}
}
+ pub fn auth_count(&self) -> usize {
+ self.auth_count.load(SeqCst)
+ }
+
pub fn forbid_connections(&self) {
self.forbid_connections.store(true, SeqCst);
}