Add integration test simulating killing a connection while chatting

Antonio Scandurra created

Change summary

server/src/rpc.rs | 216 +++++++++++++++++++++++++++++++++++++++++++++---
zed/src/test.rs   |   2 
zrpc/src/conn.rs  |  61 ++++++++++---
zrpc/src/peer.rs  |   8 
4 files changed, 254 insertions(+), 33 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -1011,16 +1011,24 @@ mod tests {
     };
     use async_std::{sync::RwLockReadGuard, task};
     use gpui::TestAppContext;
-    use postage::mpsc;
+    use parking_lot::Mutex;
+    use postage::{mpsc, watch};
     use serde_json::json;
     use sqlx::types::time::OffsetDateTime;
-    use std::{path::Path, sync::Arc, time::Duration};
+    use std::{
+        path::Path,
+        sync::{
+            atomic::{AtomicBool, Ordering::SeqCst},
+            Arc,
+        },
+        time::Duration,
+    };
     use zed::{
         channel::{Channel, ChannelDetails, ChannelList},
         editor::{Editor, Insert},
         fs::{FakeFs, Fs as _},
         language::LanguageRegistry,
-        rpc::Client,
+        rpc::{self, Client},
         settings,
         user::UserStore,
         worktree::Worktree,
@@ -1677,11 +1685,168 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
+        // Connect to a server as 2 clients.
+        let mut server = TestServer::start().await;
+        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let mut status_b = client_b.status();
+
+        // Create an org that includes these 2 users.
+        let db = &server.app_state.db;
+        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+        db.add_org_member(org_id, user_id_a, false).await.unwrap();
+        db.add_org_member(org_id, user_id_b, false).await.unwrap();
+
+        // Create a channel that includes all the users.
+        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+        db.add_channel_member(channel_id, user_id_a, false)
+            .await
+            .unwrap();
+        db.add_channel_member(channel_id, user_id_b, false)
+            .await
+            .unwrap();
+        db.create_channel_message(
+            channel_id,
+            user_id_b,
+            "hello A, it's B.",
+            OffsetDateTime::now_utc(),
+        )
+        .await
+        .unwrap();
+
+        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+        channels_a
+            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+            .await;
+
+        channels_a.read_with(&cx_a, |list, _| {
+            assert_eq!(
+                list.available_channels().unwrap(),
+                &[ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+            this.get_channel(channel_id.to_proto(), cx).unwrap()
+        });
+        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
+        channel_a
+            .condition(&cx_a, |channel, _| {
+                channel_messages(channel)
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+            })
+            .await;
+
+        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
+        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
+        channels_b
+            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
+            .await;
+        channels_b.read_with(&cx_b, |list, _| {
+            assert_eq!(
+                list.available_channels().unwrap(),
+                &[ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+
+        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
+            this.get_channel(channel_id.to_proto(), cx).unwrap()
+        });
+        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+            })
+            .await;
+
+        // Disconnect client B, ensuring we can still access its cached channel data.
+        server.forbid_connections();
+        server.disconnect_client(user_id_b);
+        while !matches!(
+            status_b.recv().await,
+            Some(rpc::Status::ReconnectionError { .. })
+        ) {}
+
+        channels_b.read_with(&cx_b, |channels, _| {
+            assert_eq!(
+                channels.available_channels().unwrap(),
+                [ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+        channel_b.read_with(&cx_b, |channel, _| {
+            assert_eq!(
+                channel_messages(channel),
+                [("user_b".to_string(), "hello A, it's B.".to_string())]
+            )
+        });
+
+        // Send a message from client A while B is disconnected.
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
+                channel
+                    .send_message("oh, hi B.".to_string(), cx)
+                    .unwrap()
+                    .detach();
+                let task = channel.send_message("sup".to_string(), cx).unwrap();
+                assert_eq!(
+                    channel
+                        .pending_messages()
+                        .iter()
+                        .map(|m| &m.body)
+                        .collect::<Vec<_>>(),
+                    &["oh, hi B.", "sup"]
+                );
+                task
+            })
+            .await
+            .unwrap();
+
+        // Give client B a chance to reconnect.
+        server.allow_connections();
+        cx_b.foreground().advance_clock(Duration::from_secs(10));
+
+        // Verify that B sees the new messages upon reconnection.
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        ("user_b".to_string(), "hello A, it's B.".to_string()),
+                        ("user_a".to_string(), "oh, hi B.".to_string()),
+                        ("user_a".to_string(), "sup".to_string()),
+                    ]
+            })
+            .await;
+
+        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
+            channel
+                .messages()
+                .cursor::<(), ()>()
+                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
+                .collect()
+        }
+    }
+
     struct TestServer {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,
         server: Arc<Server>,
         notifications: mpsc::Receiver<()>,
+        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+        forbid_connections: Arc<AtomicBool>,
         _test_db: TestDb,
     }
 
@@ -1697,6 +1862,8 @@ mod tests {
                 app_state,
                 server,
                 notifications: notifications.1,
+                connection_killers: Default::default(),
+                forbid_connections: Default::default(),
                 _test_db: test_db,
             }
         }
@@ -1710,6 +1877,8 @@ mod tests {
             let client_name = name.to_string();
             let mut client = Client::new();
             let server = self.server.clone();
+            let connection_killers = self.connection_killers.clone();
+            let forbid_connections = self.forbid_connections.clone();
             Arc::get_mut(&mut client)
                 .unwrap()
                 .set_login_and_connect_callbacks(
@@ -1719,15 +1888,20 @@ mod tests {
                             Ok((client_user_id.0 as u64, access_token))
                         })
                     },
-                    {
-                        move |user_id, access_token, cx| {
-                            assert_eq!(user_id, client_user_id.0 as u64);
-                            assert_eq!(access_token, "the-token");
-
-                            let server = server.clone();
-                            let client_name = client_name.clone();
-                            cx.spawn(move |cx| async move {
-                                let (client_conn, server_conn) = Conn::in_memory();
+                    move |user_id, access_token, cx| {
+                        assert_eq!(user_id, client_user_id.0 as u64);
+                        assert_eq!(access_token, "the-token");
+
+                        let server = server.clone();
+                        let connection_killers = connection_killers.clone();
+                        let forbid_connections = forbid_connections.clone();
+                        let client_name = client_name.clone();
+                        cx.spawn(move |cx| async move {
+                            if forbid_connections.load(SeqCst) {
+                                Err(anyhow!("server is forbidding connections"))
+                            } else {
+                                let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+                                connection_killers.lock().insert(client_user_id, kill_conn);
                                 cx.background()
                                     .spawn(server.handle_connection(
                                         server_conn,
@@ -1736,8 +1910,8 @@ mod tests {
                                     ))
                                     .detach();
                                 Ok(client_conn)
-                            })
-                        }
+                            }
+                        })
                     },
                 );
 
@@ -1748,6 +1922,20 @@ mod tests {
             (client_user_id, client)
         }
 
+        fn disconnect_client(&self, user_id: UserId) {
+            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
+                let _ = kill_conn.try_send(Some(()));
+            }
+        }
+
+        fn forbid_connections(&self) {
+            self.forbid_connections.store(true, SeqCst);
+        }
+
+        fn allow_connections(&self) {
+            self.forbid_connections.store(false, SeqCst);
+        }
+
         async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
             let mut config = Config::default();
             config.session_secret = "a".repeat(32);

zed/src/test.rs 🔗

@@ -258,7 +258,7 @@ impl FakeServer {
         if self.forbid_connections.load(SeqCst) {
             Err(anyhow!("server is forbidding connections"))
         } else {
-            let (client_conn, server_conn) = Conn::in_memory();
+            let (client_conn, server_conn, _) = Conn::in_memory();
             let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
             cx.background().spawn(io).detach();
             *self.incoming.lock() = Some(incoming);

zrpc/src/conn.rs 🔗

@@ -33,22 +33,55 @@ impl Conn {
     }
 
     #[cfg(any(test, feature = "test-support"))]
-    pub fn in_memory() -> (Self, Self) {
-        use futures::SinkExt as _;
-        use futures::StreamExt as _;
-        use std::io::{Error, ErrorKind};
+    pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
+        let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
+        postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
 
-        let (a_tx, a_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
-        let (b_tx, b_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+        let (a_tx, a_rx) = Self::channel(kill_rx.clone());
+        let (b_tx, b_rx) = Self::channel(kill_rx);
         (
-            Self {
-                tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
-                rx: Box::new(b_rx.map(Ok)),
-            },
-            Self {
-                tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
-                rx: Box::new(a_rx.map(Ok)),
-            },
+            Self { tx: a_tx, rx: b_rx },
+            Self { tx: b_tx, rx: a_rx },
+            kill_tx,
         )
     }
+
+    #[cfg(any(test, feature = "test-support"))]
+    fn channel(
+        kill_rx: postage::watch::Receiver<Option<()>>,
+    ) -> (
+        Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+        Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
+    ) {
+        use futures::{future, stream, SinkExt as _, StreamExt as _};
+        use std::io::{Error, ErrorKind};
+
+        let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+        let tx = tx
+            .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+            .with({
+                let kill_rx = kill_rx.clone();
+                move |msg| {
+                    if kill_rx.borrow().is_none() {
+                        future::ready(Ok(msg))
+                    } else {
+                        future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
+                    }
+                }
+            });
+        let rx = stream::select(
+            rx.map(Ok),
+            kill_rx.filter_map(|kill| {
+                if let Some(_) = kill {
+                    future::ready(Some(Err(
+                        Error::new(ErrorKind::Other, "connection killed").into()
+                    )))
+                } else {
+                    future::ready(None)
+                }
+            }),
+        );
+
+        (Box::new(tx), Box::new(rx))
+    }
 }

zrpc/src/peer.rs 🔗

@@ -352,12 +352,12 @@ mod tests {
             let client1 = Peer::new();
             let client2 = Peer::new();
 
-            let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory();
+            let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
             let (client1_conn_id, io_task1, _) =
                 client1.add_connection(client1_to_server_conn).await;
             let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
 
-            let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory();
+            let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
             let (client2_conn_id, io_task3, _) =
                 client2.add_connection(client2_to_server_conn).await;
             let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -492,7 +492,7 @@ mod tests {
     #[test]
     fn test_disconnect() {
         smol::block_on(async move {
-            let (client_conn, mut server_conn) = Conn::in_memory();
+            let (client_conn, mut server_conn, _) = Conn::in_memory();
 
             let client = Peer::new();
             let (connection_id, io_handler, mut incoming) =
@@ -526,7 +526,7 @@ mod tests {
     #[test]
     fn test_io_error() {
         smol::block_on(async move {
-            let (client_conn, server_conn) = Conn::in_memory();
+            let (client_conn, server_conn, _) = Conn::in_memory();
             drop(server_conn);
 
             let client = Peer::new();