diff --git a/server/src/rpc.rs b/server/src/rpc.rs index f4d16af996139e596685cc6b5ec25b93655a4d3a..cc0d35d097a2bf860947ead3eb774e8eaf1b9c91 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1011,16 +1011,24 @@ mod tests { }; use async_std::{sync::RwLockReadGuard, task}; use gpui::TestAppContext; - use postage::mpsc; + use parking_lot::Mutex; + use postage::{mpsc, watch}; use serde_json::json; use sqlx::types::time::OffsetDateTime; - use std::{path::Path, sync::Arc, time::Duration}; + use std::{ + path::Path, + sync::{ + atomic::{AtomicBool, Ordering::SeqCst}, + Arc, + }, + time::Duration, + }; use zed::{ channel::{Channel, ChannelDetails, ChannelList}, editor::{Editor, Insert}, fs::{FakeFs, Fs as _}, language::LanguageRegistry, - rpc::Client, + rpc::{self, Client}, settings, user::UserStore, worktree::Worktree, @@ -1677,11 +1685,168 @@ mod tests { ); } + #[gpui::test] + async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { + cx_a.foreground().forbid_parking(); + + // Connect to a server as 2 clients. + let mut server = TestServer::start().await; + let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await; + let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await; + let mut status_b = client_b.status(); + + // Create an org that includes these 2 users. + let db = &server.app_state.db; + let org_id = db.create_org("Test Org", "test-org").await.unwrap(); + db.add_org_member(org_id, user_id_a, false).await.unwrap(); + db.add_org_member(org_id, user_id_b, false).await.unwrap(); + + // Create a channel that includes all the users. + let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap(); + db.add_channel_member(channel_id, user_id_a, false) + .await + .unwrap(); + db.add_channel_member(channel_id, user_id_b, false) + .await + .unwrap(); + db.create_channel_message( + channel_id, + user_id_b, + "hello A, it's B.", + OffsetDateTime::now_utc(), + ) + .await + .unwrap(); + + let user_store_a = Arc::new(UserStore::new(client_a.clone())); + let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx)); + channels_a + .condition(&mut cx_a, |list, _| list.available_channels().is_some()) + .await; + + channels_a.read_with(&cx_a, |list, _| { + assert_eq!( + list.available_channels().unwrap(), + &[ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + let channel_a = channels_a.update(&mut cx_a, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty())); + channel_a + .condition(&cx_a, |channel, _| { + channel_messages(channel) + == [("user_b".to_string(), "hello A, it's B.".to_string())] + }) + .await; + + let user_store_b = Arc::new(UserStore::new(client_b.clone())); + let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx)); + channels_b + .condition(&mut cx_b, |list, _| list.available_channels().is_some()) + .await; + channels_b.read_with(&cx_b, |list, _| { + assert_eq!( + list.available_channels().unwrap(), + &[ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + + let channel_b = channels_b.update(&mut cx_b, |this, cx| { + this.get_channel(channel_id.to_proto(), cx).unwrap() + }); + channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty())); + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [("user_b".to_string(), "hello A, it's B.".to_string())] + }) + .await; + + // Disconnect client B, ensuring we can still access its cached channel data. + server.forbid_connections(); + server.disconnect_client(user_id_b); + while !matches!( + status_b.recv().await, + Some(rpc::Status::ReconnectionError { .. }) + ) {} + + channels_b.read_with(&cx_b, |channels, _| { + assert_eq!( + channels.available_channels().unwrap(), + [ChannelDetails { + id: channel_id.to_proto(), + name: "test-channel".to_string() + }] + ) + }); + channel_b.read_with(&cx_b, |channel, _| { + assert_eq!( + channel_messages(channel), + [("user_b".to_string(), "hello A, it's B.".to_string())] + ) + }); + + // Send a message from client A while B is disconnected. + channel_a + .update(&mut cx_a, |channel, cx| { + channel + .send_message("oh, hi B.".to_string(), cx) + .unwrap() + .detach(); + let task = channel.send_message("sup".to_string(), cx).unwrap(); + assert_eq!( + channel + .pending_messages() + .iter() + .map(|m| &m.body) + .collect::>(), + &["oh, hi B.", "sup"] + ); + task + }) + .await + .unwrap(); + + // Give client B a chance to reconnect. + server.allow_connections(); + cx_b.foreground().advance_clock(Duration::from_secs(10)); + + // Verify that B sees the new messages upon reconnection. + channel_b + .condition(&cx_b, |channel, _| { + channel_messages(channel) + == [ + ("user_b".to_string(), "hello A, it's B.".to_string()), + ("user_a".to_string(), "oh, hi B.".to_string()), + ("user_a".to_string(), "sup".to_string()), + ] + }) + .await; + + fn channel_messages(channel: &Channel) -> Vec<(String, String)> { + channel + .messages() + .cursor::<(), ()>() + .map(|m| (m.sender.github_login.clone(), m.body.clone())) + .collect() + } + } + struct TestServer { peer: Arc, app_state: Arc, server: Arc, notifications: mpsc::Receiver<()>, + connection_killers: Arc>>>>, + forbid_connections: Arc, _test_db: TestDb, } @@ -1697,6 +1862,8 @@ mod tests { app_state, server, notifications: notifications.1, + connection_killers: Default::default(), + forbid_connections: Default::default(), _test_db: test_db, } } @@ -1710,6 +1877,8 @@ mod tests { let client_name = name.to_string(); let mut client = Client::new(); let server = self.server.clone(); + let connection_killers = self.connection_killers.clone(); + let forbid_connections = self.forbid_connections.clone(); Arc::get_mut(&mut client) .unwrap() .set_login_and_connect_callbacks( @@ -1719,15 +1888,20 @@ mod tests { Ok((client_user_id.0 as u64, access_token)) }) }, - { - move |user_id, access_token, cx| { - assert_eq!(user_id, client_user_id.0 as u64); - assert_eq!(access_token, "the-token"); - - let server = server.clone(); - let client_name = client_name.clone(); - cx.spawn(move |cx| async move { - let (client_conn, server_conn) = Conn::in_memory(); + move |user_id, access_token, cx| { + assert_eq!(user_id, client_user_id.0 as u64); + assert_eq!(access_token, "the-token"); + + let server = server.clone(); + let connection_killers = connection_killers.clone(); + let forbid_connections = forbid_connections.clone(); + let client_name = client_name.clone(); + cx.spawn(move |cx| async move { + if forbid_connections.load(SeqCst) { + Err(anyhow!("server is forbidding connections")) + } else { + let (client_conn, server_conn, kill_conn) = Conn::in_memory(); + connection_killers.lock().insert(client_user_id, kill_conn); cx.background() .spawn(server.handle_connection( server_conn, @@ -1736,8 +1910,8 @@ mod tests { )) .detach(); Ok(client_conn) - }) - } + } + }) }, ); @@ -1748,6 +1922,20 @@ mod tests { (client_user_id, client) } + fn disconnect_client(&self, user_id: UserId) { + if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) { + let _ = kill_conn.try_send(Some(())); + } + } + + fn forbid_connections(&self) { + self.forbid_connections.store(true, SeqCst); + } + + fn allow_connections(&self) { + self.forbid_connections.store(false, SeqCst); + } + async fn build_app_state(test_db: &TestDb) -> Arc { let mut config = Config::default(); config.session_secret = "a".repeat(32); diff --git a/zed/src/test.rs b/zed/src/test.rs index cf1fbfd9e8b04439be6dbe949f86bb20ad4284d7..ce865bbfe58d64c267267cbfbe85321a87a7ca37 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -258,7 +258,7 @@ impl FakeServer { if self.forbid_connections.load(SeqCst) { Err(anyhow!("server is forbidding connections")) } else { - let (client_conn, server_conn) = Conn::in_memory(); + let (client_conn, server_conn, _) = Conn::in_memory(); let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; cx.background().spawn(io).detach(); *self.incoming.lock() = Some(incoming); diff --git a/zrpc/src/conn.rs b/zrpc/src/conn.rs index 06dbcee0774e23dbbcfb6654b44673349f046b68..f25fb12f01f651f77483e8ac021166246ae18ff0 100644 --- a/zrpc/src/conn.rs +++ b/zrpc/src/conn.rs @@ -33,22 +33,55 @@ impl Conn { } #[cfg(any(test, feature = "test-support"))] - pub fn in_memory() -> (Self, Self) { - use futures::SinkExt as _; - use futures::StreamExt as _; - use std::io::{Error, ErrorKind}; + pub fn in_memory() -> (Self, Self, postage::watch::Sender>) { + let (kill_tx, mut kill_rx) = postage::watch::channel_with(None); + postage::stream::Stream::try_recv(&mut kill_rx).unwrap(); - let (a_tx, a_rx) = futures::channel::mpsc::unbounded::(); - let (b_tx, b_rx) = futures::channel::mpsc::unbounded::(); + let (a_tx, a_rx) = Self::channel(kill_rx.clone()); + let (b_tx, b_rx) = Self::channel(kill_rx); ( - Self { - tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), - rx: Box::new(b_rx.map(Ok)), - }, - Self { - tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())), - rx: Box::new(a_rx.map(Ok)), - }, + Self { tx: a_tx, rx: b_rx }, + Self { tx: b_tx, rx: a_rx }, + kill_tx, ) } + + #[cfg(any(test, feature = "test-support"))] + fn channel( + kill_rx: postage::watch::Receiver>, + ) -> ( + Box>, + Box>>, + ) { + use futures::{future, stream, SinkExt as _, StreamExt as _}; + use std::io::{Error, ErrorKind}; + + let (tx, rx) = futures::channel::mpsc::unbounded::(); + let tx = tx + .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e))) + .with({ + let kill_rx = kill_rx.clone(); + move |msg| { + if kill_rx.borrow().is_none() { + future::ready(Ok(msg)) + } else { + future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into())) + } + } + }); + let rx = stream::select( + rx.map(Ok), + kill_rx.filter_map(|kill| { + if let Some(_) = kill { + future::ready(Some(Err( + Error::new(ErrorKind::Other, "connection killed").into() + ))) + } else { + future::ready(None) + } + }), + ); + + (Box::new(tx), Box::new(rx)) + } } diff --git a/zrpc/src/peer.rs b/zrpc/src/peer.rs index d50ee50ec3f4e099e153852be634d3470eb8603c..93c413acbcab429552f0ee123c8dee5d3b3e860c 100644 --- a/zrpc/src/peer.rs +++ b/zrpc/src/peer.rs @@ -352,12 +352,12 @@ mod tests { let client1 = Peer::new(); let client2 = Peer::new(); - let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory(); + let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory(); let (client1_conn_id, io_task1, _) = client1.add_connection(client1_to_server_conn).await; let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; - let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory(); + let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory(); let (client2_conn_id, io_task3, _) = client2.add_connection(client2_to_server_conn).await; let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; @@ -492,7 +492,7 @@ mod tests { #[test] fn test_disconnect() { smol::block_on(async move { - let (client_conn, mut server_conn) = Conn::in_memory(); + let (client_conn, mut server_conn, _) = Conn::in_memory(); let client = Peer::new(); let (connection_id, io_handler, mut incoming) = @@ -526,7 +526,7 @@ mod tests { #[test] fn test_io_error() { smol::block_on(async move { - let (client_conn, server_conn) = Conn::in_memory(); + let (client_conn, server_conn, _) = Conn::in_memory(); drop(server_conn); let client = Peer::new();