@@ -1011,16 +1011,24 @@ mod tests {
};
use async_std::{sync::RwLockReadGuard, task};
use gpui::TestAppContext;
- use postage::mpsc;
+ use parking_lot::Mutex;
+ use postage::{mpsc, watch};
use serde_json::json;
use sqlx::types::time::OffsetDateTime;
- use std::{path::Path, sync::Arc, time::Duration};
+ use std::{
+ path::Path,
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
+ time::Duration,
+ };
use zed::{
channel::{Channel, ChannelDetails, ChannelList},
editor::{Editor, Insert},
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
- rpc::Client,
+ rpc::{self, Client},
settings,
user::UserStore,
worktree::Worktree,
@@ -1677,11 +1685,168 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+ cx_a.foreground().forbid_parking();
+
+ // Connect to a server as 2 clients.
+ let mut server = TestServer::start().await;
+ let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+ let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+ let mut status_b = client_b.status();
+
+ // Create an org that includes these 2 users.
+ let db = &server.app_state.db;
+ let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+ db.add_org_member(org_id, user_id_a, false).await.unwrap();
+ db.add_org_member(org_id, user_id_b, false).await.unwrap();
+
+ // Create a channel that includes all the users.
+ let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+ db.add_channel_member(channel_id, user_id_a, false)
+ .await
+ .unwrap();
+ db.add_channel_member(channel_id, user_id_b, false)
+ .await
+ .unwrap();
+ db.create_channel_message(
+ channel_id,
+ user_id_b,
+ "hello A, it's B.",
+ OffsetDateTime::now_utc(),
+ )
+ .await
+ .unwrap();
+
+ let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+ let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+ channels_a
+ .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+ .await;
+
+ channels_a.read_with(&cx_a, |list, _| {
+ assert_eq!(
+ list.available_channels().unwrap(),
+ &[ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+ let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+ this.get_channel(channel_id.to_proto(), cx).unwrap()
+ });
+ channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
+ channel_a
+ .condition(&cx_a, |channel, _| {
+ channel_messages(channel)
+ == [("user_b".to_string(), "hello A, it's B.".to_string())]
+ })
+ .await;
+
+ let user_store_b = Arc::new(UserStore::new(client_b.clone()));
+ let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
+ channels_b
+ .condition(&mut cx_b, |list, _| list.available_channels().is_some())
+ .await;
+ channels_b.read_with(&cx_b, |list, _| {
+ assert_eq!(
+ list.available_channels().unwrap(),
+ &[ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+
+ let channel_b = channels_b.update(&mut cx_b, |this, cx| {
+ this.get_channel(channel_id.to_proto(), cx).unwrap()
+ });
+ channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
+ channel_b
+ .condition(&cx_b, |channel, _| {
+ channel_messages(channel)
+ == [("user_b".to_string(), "hello A, it's B.".to_string())]
+ })
+ .await;
+
+ // Disconnect client B, ensuring we can still access its cached channel data.
+ server.forbid_connections();
+ server.disconnect_client(user_id_b);
+ while !matches!(
+ status_b.recv().await,
+ Some(rpc::Status::ReconnectionError { .. })
+ ) {}
+
+ channels_b.read_with(&cx_b, |channels, _| {
+ assert_eq!(
+ channels.available_channels().unwrap(),
+ [ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+ channel_b.read_with(&cx_b, |channel, _| {
+ assert_eq!(
+ channel_messages(channel),
+ [("user_b".to_string(), "hello A, it's B.".to_string())]
+ )
+ });
+
+ // Send a message from client A while B is disconnected.
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
+ channel
+ .send_message("oh, hi B.".to_string(), cx)
+ .unwrap()
+ .detach();
+ let task = channel.send_message("sup".to_string(), cx).unwrap();
+ assert_eq!(
+ channel
+ .pending_messages()
+ .iter()
+ .map(|m| &m.body)
+ .collect::<Vec<_>>(),
+ &["oh, hi B.", "sup"]
+ );
+ task
+ })
+ .await
+ .unwrap();
+
+ // Give client B a chance to reconnect.
+ server.allow_connections();
+ cx_b.foreground().advance_clock(Duration::from_secs(10));
+
+ // Verify that B sees the new messages upon reconnection.
+ channel_b
+ .condition(&cx_b, |channel, _| {
+ channel_messages(channel)
+ == [
+ ("user_b".to_string(), "hello A, it's B.".to_string()),
+ ("user_a".to_string(), "oh, hi B.".to_string()),
+ ("user_a".to_string(), "sup".to_string()),
+ ]
+ })
+ .await;
+
+ fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
+ channel
+ .messages()
+ .cursor::<(), ()>()
+ .map(|m| (m.sender.github_login.clone(), m.body.clone()))
+ .collect()
+ }
+ }
+
struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
notifications: mpsc::Receiver<()>,
+ connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+ forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
@@ -1697,6 +1862,8 @@ mod tests {
app_state,
server,
notifications: notifications.1,
+ connection_killers: Default::default(),
+ forbid_connections: Default::default(),
_test_db: test_db,
}
}
@@ -1710,6 +1877,8 @@ mod tests {
let client_name = name.to_string();
let mut client = Client::new();
let server = self.server.clone();
+ let connection_killers = self.connection_killers.clone();
+ let forbid_connections = self.forbid_connections.clone();
Arc::get_mut(&mut client)
.unwrap()
.set_login_and_connect_callbacks(
@@ -1719,15 +1888,20 @@ mod tests {
Ok((client_user_id.0 as u64, access_token))
})
},
- {
- move |user_id, access_token, cx| {
- assert_eq!(user_id, client_user_id.0 as u64);
- assert_eq!(access_token, "the-token");
-
- let server = server.clone();
- let client_name = client_name.clone();
- cx.spawn(move |cx| async move {
- let (client_conn, server_conn) = Conn::in_memory();
+ move |user_id, access_token, cx| {
+ assert_eq!(user_id, client_user_id.0 as u64);
+ assert_eq!(access_token, "the-token");
+
+ let server = server.clone();
+ let connection_killers = connection_killers.clone();
+ let forbid_connections = forbid_connections.clone();
+ let client_name = client_name.clone();
+ cx.spawn(move |cx| async move {
+ if forbid_connections.load(SeqCst) {
+ Err(anyhow!("server is forbidding connections"))
+ } else {
+ let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+ connection_killers.lock().insert(client_user_id, kill_conn);
cx.background()
.spawn(server.handle_connection(
server_conn,
@@ -1736,8 +1910,8 @@ mod tests {
))
.detach();
Ok(client_conn)
- })
- }
+ }
+ })
},
);
@@ -1748,6 +1922,20 @@ mod tests {
(client_user_id, client)
}
+ fn disconnect_client(&self, user_id: UserId) {
+ if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
+ let _ = kill_conn.try_send(Some(()));
+ }
+ }
+
+ fn forbid_connections(&self) {
+ self.forbid_connections.store(true, SeqCst);
+ }
+
+ fn allow_connections(&self) {
+ self.forbid_connections.store(false, SeqCst);
+ }
+
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
let mut config = Config::default();
config.session_secret = "a".repeat(32);
@@ -33,22 +33,55 @@ impl Conn {
}
#[cfg(any(test, feature = "test-support"))]
- pub fn in_memory() -> (Self, Self) {
- use futures::SinkExt as _;
- use futures::StreamExt as _;
- use std::io::{Error, ErrorKind};
+ pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
+ let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
+ postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
- let (a_tx, a_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
- let (b_tx, b_rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+ let (a_tx, a_rx) = Self::channel(kill_rx.clone());
+ let (b_tx, b_rx) = Self::channel(kill_rx);
(
- Self {
- tx: Box::new(a_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
- rx: Box::new(b_rx.map(Ok)),
- },
- Self {
- tx: Box::new(b_tx.sink_map_err(|e| Error::new(ErrorKind::Other, e).into())),
- rx: Box::new(a_rx.map(Ok)),
- },
+ Self { tx: a_tx, rx: b_rx },
+ Self { tx: b_tx, rx: a_rx },
+ kill_tx,
)
}
+
+ #[cfg(any(test, feature = "test-support"))]
+ fn channel(
+ kill_rx: postage::watch::Receiver<Option<()>>,
+ ) -> (
+ Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
+ ) {
+ use futures::{future, stream, SinkExt as _, StreamExt as _};
+ use std::io::{Error, ErrorKind};
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+ let tx = tx
+ .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+ .with({
+ let kill_rx = kill_rx.clone();
+ move |msg| {
+ if kill_rx.borrow().is_none() {
+ future::ready(Ok(msg))
+ } else {
+ future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
+ }
+ }
+ });
+ let rx = stream::select(
+ rx.map(Ok),
+ kill_rx.filter_map(|kill| {
+ if let Some(_) = kill {
+ future::ready(Some(Err(
+ Error::new(ErrorKind::Other, "connection killed").into()
+ )))
+ } else {
+ future::ready(None)
+ }
+ }),
+ );
+
+ (Box::new(tx), Box::new(rx))
+ }
}
@@ -352,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_to_server_conn, server_to_client_1_conn) = Conn::in_memory();
+ let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
- let (client2_to_server_conn, server_to_client_2_conn) = Conn::in_memory();
+ let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -492,7 +492,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
- let (client_conn, mut server_conn) = Conn::in_memory();
+ let (client_conn, mut server_conn, _) = Conn::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@@ -526,7 +526,7 @@ mod tests {
#[test]
fn test_io_error() {
smol::block_on(async move {
- let (client_conn, server_conn) = Conn::in_memory();
+ let (client_conn, server_conn, _) = Conn::in_memory();
drop(server_conn);
let client = Peer::new();