@@ -1693,20 +1693,46 @@ mod tests {
cx: &mut TestAppContext,
name: &str,
) -> (UserId, Arc<Client>) {
- let user_id = self.app_state.db.create_user(name, false).await.unwrap();
- let client = Client::new();
- let (client_conn, server_conn) = Conn::in_memory();
- cx.background()
- .spawn(
- self.server
- .handle_connection(server_conn, name.to_string(), user_id),
- )
- .detach();
+ let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
+ let client_name = name.to_string();
+ let mut client = Client::new();
+ let server = self.server.clone();
+ Arc::get_mut(&mut client)
+ .unwrap()
+ .set_login_and_connect_callbacks(
+ move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok((client_user_id.0 as u64, access_token))
+ })
+ },
+ {
+ move |user_id, access_token, cx| {
+ assert_eq!(user_id, client_user_id.0 as u64);
+ assert_eq!(access_token, "the-token");
+
+ let server = server.clone();
+ let client_name = client_name.clone();
+ cx.spawn(move |cx| async move {
+ let (client_conn, server_conn) = Conn::in_memory();
+ cx.background()
+ .spawn(server.handle_connection(
+ server_conn,
+ client_name,
+ client_user_id,
+ ))
+ .detach();
+ Ok(client_conn)
+ })
+ }
+ },
+ );
+
client
- .set_connection(user_id.to_proto(), client_conn, &cx.to_async())
+ .authenticate_and_connect(&cx.to_async())
.await
.unwrap();
- (user_id, client)
+ (client_user_id, client)
}
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
@@ -49,7 +49,10 @@ pub enum Status {
user_id: u64,
},
ConnectionLost,
- Reconnecting,
+ Reauthenticating,
+ Reconnecting {
+ user_id: u64,
+ },
ReconnectionError {
next_reconnection: Instant,
},
@@ -164,9 +167,10 @@ impl Client {
}
}));
}
- _ => {
+ Status::Disconnected => {
state._maintain_connection.take();
}
+ _ => {}
}
}
@@ -227,14 +231,20 @@ impl Client {
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
- if matches!(
- *self.status().borrow(),
- Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
- ) {
- return Ok(());
- }
+ let was_disconnected = match *self.status().borrow() {
+ Status::Disconnected => true,
+ Status::Connected { .. }
+ | Status::Connecting { .. }
+ | Status::Reconnecting { .. }
+ | Status::Reauthenticating => return Ok(()),
+ _ => false,
+ };
- self.set_status(Status::Authenticating, cx);
+ if was_disconnected {
+ self.set_status(Status::Authenticating, cx);
+ } else {
+ self.set_status(Status::Reauthenticating, cx)
+ }
let (user_id, access_token) = match self.authenticate(&cx).await {
Ok(result) => result,
@@ -244,27 +254,25 @@ impl Client {
}
};
- self.set_status(Status::Connecting { user_id }, cx);
-
- let conn = match self.connect(user_id, &access_token, cx).await {
- Ok(conn) => conn,
+ if was_disconnected {
+ self.set_status(Status::Connecting { user_id }, cx);
+ } else {
+ self.set_status(Status::Reconnecting { user_id }, cx);
+ }
+ match self.connect(user_id, &access_token, cx).await {
+ Ok(conn) => {
+ log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+ self.set_connection(user_id, conn, cx).await;
+ Ok(())
+ }
Err(err) => {
self.set_status(Status::ConnectionError, cx);
- return Err(err);
+ Err(err)
}
- };
-
- self.set_connection(user_id, conn, cx).await?;
- log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- Ok(())
+ }
}
- pub async fn set_connection(
- self: &Arc<Self>,
- user_id: u64,
- conn: Conn,
- cx: &AsyncAppContext,
- ) -> Result<()> {
+ async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
cx.foreground()
.spawn({
@@ -321,7 +329,6 @@ impl Client {
}
})
.detach();
- Ok(())
}
fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
@@ -489,35 +496,6 @@ impl Client {
}
}
-pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
- type Output: 'a + Future<Output = anyhow::Result<()>>;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Client,
- cx: &'a mut gpui::AsyncAppContext,
- ) -> Self::Output;
-}
-
-impl<'a, M, F, Fut> MessageHandler<'a, M> for F
-where
- M: proto::EnvelopedMessage,
- F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
- Fut: 'a + Future<Output = anyhow::Result<()>>,
-{
- type Output = Fut;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Client,
- cx: &'a mut gpui::AsyncAppContext,
- ) -> Self::Output {
- (self)(message, rpc, cx)
- }
-}
-
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@@ -550,6 +528,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_heartbeat(cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
let user_id = 5;
let mut client = Client::new();
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
@@ -568,6 +548,28 @@ mod tests {
assert!(server.receive::<proto::Ping>().await.is_err());
}
+ #[gpui::test(iterations = 10)]
+ async fn test_reconnection(cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new();
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+ let mut status = client.status();
+ assert!(matches!(
+ status.recv().await,
+ Some(Status::Connected { .. })
+ ));
+
+ server.forbid_connections();
+ server.disconnect().await;
+ while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
+
+ server.allow_connections();
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+ }
+
#[test]
fn test_encode_and_decode_worktree_url() {
let url = encode_worktree_url(5, "deadbeef");
@@ -17,7 +17,10 @@ use smol::channel;
use std::{
marker::PhantomData,
path::{Path, PathBuf},
- sync::Arc,
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
};
use tempdir::TempDir;
use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
@@ -200,6 +203,7 @@ pub struct FakeServer {
peer: Arc<Peer>,
incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
connection_id: Mutex<Option<ConnectionId>>,
+ forbid_connections: AtomicBool,
}
impl FakeServer {
@@ -212,6 +216,7 @@ impl FakeServer {
peer: Peer::new(),
incoming: Default::default(),
connection_id: Default::default(),
+ forbid_connections: Default::default(),
});
Arc::get_mut(client)
@@ -230,15 +235,14 @@ impl FakeServer {
assert_eq!(access_token, "the-token");
cx.spawn({
let server = server.clone();
- move |cx| async move { Ok(server.connect(&cx).await) }
+ move |cx| async move { server.connect(&cx).await }
})
}
},
);
- let conn = result.connect(&cx.to_async()).await;
client
- .set_connection(client_user_id, conn, &cx.to_async())
+ .authenticate_and_connect(&cx.to_async())
.await
.unwrap();
result
@@ -250,13 +254,25 @@ impl FakeServer {
self.incoming.lock().take();
}
- async fn connect(&self, cx: &AsyncAppContext) -> Conn {
- let (client_conn, server_conn) = Conn::in_memory();
- let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
- *self.incoming.lock() = Some(incoming);
- *self.connection_id.lock() = Some(connection_id);
- client_conn
+ async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
+ if self.forbid_connections.load(SeqCst) {
+ Err(anyhow!("server is forbidding connections"))
+ } else {
+ let (client_conn, server_conn) = Conn::in_memory();
+ let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+ *self.incoming.lock() = Some(incoming);
+ *self.connection_id.lock() = Some(connection_id);
+ Ok(client_conn)
+ }
+ }
+
+ pub fn forbid_connections(&self) {
+ self.forbid_connections.store(true, SeqCst);
+ }
+
+ pub fn allow_connections(&self) {
+ self.forbid_connections.store(false, SeqCst);
}
pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {