Detailed changes
@@ -6,6 +6,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
use gpui::{executor, ModelHandle, TestAppContext};
use parking_lot::Mutex;
+use postage::barrier;
use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
use std::{fmt, rc::Rc, sync::Arc};
@@ -22,6 +23,7 @@ struct FakeServerState {
connection_id: Option<ConnectionId>,
forbid_connections: bool,
auth_count: usize,
+ connection_killer: Option<barrier::Sender>,
access_token: usize,
}
@@ -74,13 +76,15 @@ impl FakeServer {
Err(EstablishConnectionError::Unauthorized)?
}
- let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
+ let (client_conn, server_conn, kill) =
+ Connection::in_memory(cx.background());
let (connection_id, io, incoming) =
peer.add_test_connection(server_conn, cx.background()).await;
cx.background().spawn(io).detach();
let mut state = state.lock();
state.connection_id = Some(connection_id);
state.incoming = Some(incoming);
+ state.connection_killer = Some(kill);
Ok(client_conn)
})
}
@@ -1,6 +1,5 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{SinkExt as _, Stream, StreamExt as _};
-use std::{io, task::Poll};
+use futures::{SinkExt as _, StreamExt as _};
pub struct Connection {
pub(crate) tx:
@@ -36,87 +35,82 @@ impl Connection {
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory(
executor: std::sync::Arc<gpui::executor::Background>,
- ) -> (Self, Self, postage::watch::Sender<Option<()>>) {
- let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
- postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
+ ) -> (Self, Self, postage::barrier::Sender) {
+ use postage::prelude::Stream;
- let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone());
- let (b_tx, b_rx) = Self::channel(kill_rx, executor);
- (
+ let (kill_tx, kill_rx) = postage::barrier::channel();
+ let (a_tx, a_rx) = channel(kill_rx.clone(), executor.clone());
+ let (b_tx, b_rx) = channel(kill_rx, executor);
+ return (
Self { tx: a_tx, rx: b_rx },
Self { tx: b_tx, rx: a_rx },
kill_tx,
- )
- }
+ );
- #[cfg(any(test, feature = "test-support"))]
- fn channel(
- kill_rx: postage::watch::Receiver<Option<()>>,
- executor: std::sync::Arc<gpui::executor::Background>,
- ) -> (
- Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
- Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
- ) {
- use futures::channel::mpsc;
- use io::{Error, ErrorKind};
- use std::sync::Arc;
+ fn channel(
+ kill_rx: postage::barrier::Receiver,
+ executor: std::sync::Arc<gpui::executor::Background>,
+ ) -> (
+ Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ Box<
+ dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ >,
+ ) {
+ use futures::channel::mpsc;
+ use std::{
+ io::{Error, ErrorKind},
+ sync::Arc,
+ };
- let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
- let tx = tx
- .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
- .with({
- let executor = Arc::downgrade(&executor);
+ let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
+
+ let tx = tx
+ .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+ .with({
+ let kill_rx = kill_rx.clone();
+ let executor = Arc::downgrade(&executor);
+ move |msg| {
+ let mut kill_rx = kill_rx.clone();
+ let executor = executor.clone();
+ Box::pin(async move {
+ if let Some(executor) = executor.upgrade() {
+ executor.simulate_random_delay().await;
+ }
+
+ // Writes to a half-open TCP connection will error.
+ if kill_rx.try_recv().is_ok() {
+ std::io::Result::Err(
+ Error::new(ErrorKind::Other, "connection lost").into(),
+ )?;
+ }
+
+ Ok(msg)
+ })
+ }
+ });
+
+ let rx = rx.then({
let kill_rx = kill_rx.clone();
+ let executor = Arc::downgrade(&executor);
move |msg| {
- let kill_rx = kill_rx.clone();
+ let mut kill_rx = kill_rx.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
- if kill_rx.borrow().is_none() {
- Ok(msg)
- } else {
- Err(Error::new(ErrorKind::Other, "connection killed").into())
+
+ // Reads from a half-open TCP connection will hang.
+ if kill_rx.try_recv().is_ok() {
+ futures::future::pending::<()>().await;
}
+
+ Ok(msg)
})
}
});
- let rx = rx.then(move |msg| {
- let executor = Arc::downgrade(&executor);
- Box::pin(async move {
- if let Some(executor) = executor.upgrade() {
- executor.simulate_random_delay().await;
- }
- msg
- })
- });
- let rx = KillableReceiver { kill_rx, rx };
-
- (Box::new(tx), Box::new(rx))
- }
-}
-
-struct KillableReceiver<S> {
- rx: S,
- kill_rx: postage::watch::Receiver<Option<()>>,
-}
-
-impl<S: Unpin + Stream<Item = WebSocketMessage>> Stream for KillableReceiver<S> {
- type Item = Result<WebSocketMessage, WebSocketError>;
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
- Poll::Ready(Some(Err(io::Error::new(
- io::ErrorKind::Other,
- "connection killed",
- )
- .into())))
- } else {
- self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+ (Box::new(tx), Box::new(rx))
}
}
}
@@ -371,7 +371,7 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_to_server_conn, server_to_client_1_conn, _) =
+ let (client1_to_server_conn, server_to_client_1_conn, _kill) =
Connection::in_memory(cx.background());
let (client1_conn_id, io_task1, client1_incoming) = client1
.add_test_connection(client1_to_server_conn, cx.background())
@@ -380,7 +380,7 @@ mod tests {
.add_test_connection(server_to_client_1_conn, cx.background())
.await;
- let (client2_to_server_conn, server_to_client_2_conn, _) =
+ let (client2_to_server_conn, server_to_client_2_conn, _kill) =
Connection::in_memory(cx.background());
let (client2_conn_id, io_task3, client2_incoming) = client2
.add_test_connection(client2_to_server_conn, cx.background())
@@ -468,7 +468,7 @@ mod tests {
let server = Peer::new();
let client = Peer::new();
- let (client_to_server_conn, server_to_client_conn, _) =
+ let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
.add_test_connection(client_to_server_conn, cx.background())
@@ -568,7 +568,7 @@ mod tests {
let server = Peer::new();
let client = Peer::new();
- let (client_to_server_conn, server_to_client_conn, _) =
+ let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
let (client_to_server_conn_id, io_task1, mut client_incoming) = client
.add_test_connection(client_to_server_conn, cx.background())
@@ -680,7 +680,7 @@ mod tests {
async fn test_disconnect(cx: &mut TestAppContext) {
let executor = cx.foreground();
- let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+ let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new();
let (connection_id, io_handler, mut incoming) = client
@@ -716,7 +716,7 @@ mod tests {
#[gpui::test(iterations = 50)]
async fn test_io_error(cx: &mut TestAppContext) {
let executor = cx.foreground();
- let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+ let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new();
let (connection_id, io_handler, mut incoming) = client
@@ -1030,7 +1030,7 @@ mod tests {
};
use lsp;
use parking_lot::Mutex;
- use postage::{sink::Sink, watch};
+ use postage::{barrier, watch};
use project::{
fs::{FakeFs, Fs as _},
search::SearchQuery,
@@ -1872,6 +1872,7 @@ mod tests {
// Simulate connection loss for client B and ensure client A observes client B leaving the project.
server.disconnect_client(client_b.current_user_id(cx_b));
+ cx_a.foreground().advance_clock(Duration::from_secs(3));
project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 0)
.await;
@@ -3898,6 +3899,7 @@ mod tests {
// Disconnect client B, ensuring we can still access its cached channel data.
server.forbid_connections();
server.disconnect_client(client_b.current_user_id(&cx_b));
+ cx_b.foreground().advance_clock(Duration::from_secs(3));
while !matches!(
status_b.next().await,
Some(client::Status::ReconnectionError { .. })
@@ -4388,7 +4390,7 @@ mod tests {
server: Arc<Server>,
foreground: Rc<executor::Foreground>,
notifications: mpsc::UnboundedReceiver<()>,
- connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+ connection_killers: Arc<Mutex<HashMap<UserId, barrier::Sender>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
@@ -4492,9 +4494,7 @@ mod tests {
}
fn disconnect_client(&self, user_id: UserId) {
- if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
- let _ = kill_conn.try_send(Some(()));
- }
+ self.connection_killers.lock().remove(&user_id);
}
fn forbid_connections(&self) {