diff --git a/server/src/rpc.rs b/server/src/rpc.rs index c2349674444bf7f758e9028d85cf7c64590e80cd..f623f2964907e8a031b3b4f99f24d0d844c51738 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1693,20 +1693,46 @@ mod tests { cx: &mut TestAppContext, name: &str, ) -> (UserId, Arc) { - let user_id = self.app_state.db.create_user(name, false).await.unwrap(); - let client = Client::new(); - let (client_conn, server_conn) = Conn::in_memory(); - cx.background() - .spawn( - self.server - .handle_connection(server_conn, name.to_string(), user_id), - ) - .detach(); + let client_user_id = self.app_state.db.create_user(name, false).await.unwrap(); + let client_name = name.to_string(); + let mut client = Client::new(); + let server = self.server.clone(); + Arc::get_mut(&mut client) + .unwrap() + .set_login_and_connect_callbacks( + move |cx| { + cx.spawn(|_| async move { + let access_token = "the-token".to_string(); + Ok((client_user_id.0 as u64, access_token)) + }) + }, + { + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id.0 as u64); + assert_eq!(access_token, "the-token"); + + let server = server.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + let (client_conn, server_conn) = Conn::in_memory(); + cx.background() + .spawn(server.handle_connection( + server_conn, + client_name, + client_user_id, + )) + .detach(); + Ok(client_conn) + }) + } + }, + ); + client - .set_connection(user_id.to_proto(), client_conn, &cx.to_async()) + .authenticate_and_connect(&cx.to_async()) .await .unwrap(); - (user_id, client) + (client_user_id, client) } async fn build_app_state(test_db: &TestDb) -> Arc { diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index 518e29170466eff91375b5e4d42737e90f33537e..3399c1cf41b80e33b55c7b413123fbadd72d3e1c 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -49,7 +49,10 @@ pub enum Status { user_id: u64, }, ConnectionLost, - Reconnecting, + Reauthenticating, + Reconnecting { + user_id: u64, + }, ReconnectionError { next_reconnection: Instant, }, @@ -164,9 +167,10 @@ impl Client { } })); } - _ => { + Status::Disconnected => { state._maintain_connection.take(); } + _ => {} } } @@ -227,14 +231,20 @@ impl Client { self: &Arc, cx: &AsyncAppContext, ) -> anyhow::Result<()> { - if matches!( - *self.status().borrow(), - Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. } - ) { - return Ok(()); - } + let was_disconnected = match *self.status().borrow() { + Status::Disconnected => true, + Status::Connected { .. } + | Status::Connecting { .. } + | Status::Reconnecting { .. } + | Status::Reauthenticating => return Ok(()), + _ => false, + }; - self.set_status(Status::Authenticating, cx); + if was_disconnected { + self.set_status(Status::Authenticating, cx); + } else { + self.set_status(Status::Reauthenticating, cx) + } let (user_id, access_token) = match self.authenticate(&cx).await { Ok(result) => result, @@ -244,27 +254,25 @@ impl Client { } }; - self.set_status(Status::Connecting { user_id }, cx); - - let conn = match self.connect(user_id, &access_token, cx).await { - Ok(conn) => conn, + if was_disconnected { + self.set_status(Status::Connecting { user_id }, cx); + } else { + self.set_status(Status::Reconnecting { user_id }, cx); + } + match self.connect(user_id, &access_token, cx).await { + Ok(conn) => { + log::info!("connected to rpc address {}", *ZED_SERVER_URL); + self.set_connection(user_id, conn, cx).await; + Ok(()) + } Err(err) => { self.set_status(Status::ConnectionError, cx); - return Err(err); + Err(err) } - }; - - self.set_connection(user_id, conn, cx).await?; - log::info!("connected to rpc address {}", *ZED_SERVER_URL); - Ok(()) + } } - pub async fn set_connection( - self: &Arc, - user_id: u64, - conn: Conn, - cx: &AsyncAppContext, - ) -> Result<()> { + async fn set_connection(self: &Arc, user_id: u64, conn: Conn, cx: &AsyncAppContext) { let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await; cx.foreground() .spawn({ @@ -321,7 +329,6 @@ impl Client { } }) .detach(); - Ok(()) } fn authenticate(self: &Arc, cx: &AsyncAppContext) -> Task> { @@ -489,35 +496,6 @@ impl Client { } } -pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone { - type Output: 'a + Future>; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Client, - cx: &'a mut gpui::AsyncAppContext, - ) -> Self::Output; -} - -impl<'a, M, F, Fut> MessageHandler<'a, M> for F -where - M: proto::EnvelopedMessage, - F: Clone + Fn(TypedEnvelope, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut, - Fut: 'a + Future>, -{ - type Output = Fut; - - fn handle( - &self, - message: TypedEnvelope, - rpc: &'a Client, - cx: &'a mut gpui::AsyncAppContext, - ) -> Self::Output { - (self)(message, rpc, cx) - } -} - const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/"; pub fn encode_worktree_url(id: u64, access_token: &str) -> String { @@ -550,6 +528,8 @@ mod tests { #[gpui::test(iterations = 10)] async fn test_heartbeat(cx: TestAppContext) { + cx.foreground().forbid_parking(); + let user_id = 5; let mut client = Client::new(); let server = FakeServer::for_client(user_id, &mut client, &cx).await; @@ -568,6 +548,28 @@ mod tests { assert!(server.receive::().await.is_err()); } + #[gpui::test(iterations = 10)] + async fn test_reconnection(cx: TestAppContext) { + cx.foreground().forbid_parking(); + + let user_id = 5; + let mut client = Client::new(); + let server = FakeServer::for_client(user_id, &mut client, &cx).await; + let mut status = client.status(); + assert!(matches!( + status.recv().await, + Some(Status::Connected { .. }) + )); + + server.forbid_connections(); + server.disconnect().await; + while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} + + server.allow_connections(); + cx.foreground().advance_clock(Duration::from_secs(10)); + while !matches!(status.recv().await, Some(Status::Connected { .. })) {} + } + #[test] fn test_encode_and_decode_worktree_url() { let url = encode_worktree_url(5, "deadbeef"); diff --git a/zed/src/test.rs b/zed/src/test.rs index bee1537b9dd3cfa019691b8558d12e2f1a380d94..cf1fbfd9e8b04439be6dbe949f86bb20ad4284d7 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -17,7 +17,10 @@ use smol::channel; use std::{ marker::PhantomData, path::{Path, PathBuf}, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, }; use tempdir::TempDir; use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope}; @@ -200,6 +203,7 @@ pub struct FakeServer { peer: Arc, incoming: Mutex>>>, connection_id: Mutex>, + forbid_connections: AtomicBool, } impl FakeServer { @@ -212,6 +216,7 @@ impl FakeServer { peer: Peer::new(), incoming: Default::default(), connection_id: Default::default(), + forbid_connections: Default::default(), }); Arc::get_mut(client) @@ -230,15 +235,14 @@ impl FakeServer { assert_eq!(access_token, "the-token"); cx.spawn({ let server = server.clone(); - move |cx| async move { Ok(server.connect(&cx).await) } + move |cx| async move { server.connect(&cx).await } }) } }, ); - let conn = result.connect(&cx.to_async()).await; client - .set_connection(client_user_id, conn, &cx.to_async()) + .authenticate_and_connect(&cx.to_async()) .await .unwrap(); result @@ -250,13 +254,25 @@ impl FakeServer { self.incoming.lock().take(); } - async fn connect(&self, cx: &AsyncAppContext) -> Conn { - let (client_conn, server_conn) = Conn::in_memory(); - let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - *self.incoming.lock() = Some(incoming); - *self.connection_id.lock() = Some(connection_id); - client_conn + async fn connect(&self, cx: &AsyncAppContext) -> Result { + if self.forbid_connections.load(SeqCst) { + Err(anyhow!("server is forbidding connections")) + } else { + let (client_conn, server_conn) = Conn::in_memory(); + let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + *self.incoming.lock() = Some(incoming); + *self.connection_id.lock() = Some(connection_id); + Ok(client_conn) + } + } + + pub fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + pub fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); } pub async fn send(&self, message: T) {