@@ -21,12 +21,14 @@ use std::{
};
type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
+type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(u32);
struct Connection {
writer: Mutex<MessageStream<BoxedWriter>>,
+ reader: Mutex<MessageStream<BoxedReader>>,
response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
next_message_id: AtomicU32,
}
@@ -52,7 +54,8 @@ impl<T> TypedEnvelope<T> {
}
pub struct Peer {
- connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
+ connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
+ connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
@@ -62,6 +65,7 @@ impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
+ connection_close_barriers: Default::default(),
message_handlers: Default::default(),
handler_types: Default::default(),
next_connection_id: Default::default(),
@@ -102,10 +106,7 @@ impl Peer {
rx
}
- pub async fn add_connection<Conn>(
- self: &Arc<Self>,
- conn: Conn,
- ) -> (ConnectionId, impl Future<Output = Result<()>>)
+ pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -113,26 +114,44 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
- let (close_tx, mut close_rx) = barrier::channel();
- let connection = Arc::new(Connection {
- writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
- response_channels: Default::default(),
- next_message_id: Default::default(),
- });
+ self.connections.write().await.insert(
+ connection_id,
+ Arc::new(Connection {
+ reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
+ writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
+ response_channels: Default::default(),
+ next_message_id: Default::default(),
+ }),
+ );
+ connection_id
+ }
- self.connections
+ pub async fn disconnect(&self, connection_id: ConnectionId) {
+ self.connections.write().await.remove(&connection_id);
+ self.connection_close_barriers
.write()
.await
- .insert(connection_id, (connection.clone(), close_tx));
+ .remove(&connection_id);
+ }
+ pub fn handle_messages(
+ self: &Arc<Self>,
+ connection_id: ConnectionId,
+ ) -> impl Future<Output = Result<()>> + 'static {
+ let (close_tx, mut close_rx) = barrier::channel();
let this = self.clone();
- let handler_future = async move {
+ async move {
+ this.connection_close_barriers
+ .write()
+ .await
+ .insert(connection_id, close_tx);
+ let connection = this.connection(connection_id).await?;
let closed = close_rx.recv();
futures::pin_mut!(closed);
- let mut stream = MessageStream::new(conn);
loop {
- let read_message = stream.read_message();
+ let mut reader = connection.reader.lock().await;
+ let read_message = reader.read_message();
futures::pin_mut!(read_message);
match futures::future::select(read_message, &mut closed).await {
@@ -181,13 +200,23 @@ impl Peer {
Either::Right(_) => return Ok(()),
}
}
- };
-
- (connection_id, handler_future)
+ }
}
- pub async fn disconnect(&self, connection_id: ConnectionId) {
- self.connections.write().await.remove(&connection_id);
+ pub async fn receive<M: EnvelopedMessage>(
+ self: &Arc<Self>,
+ connection_id: ConnectionId,
+ ) -> Result<TypedEnvelope<M>> {
+ let connection = self.connection(connection_id).await?;
+ let envelope = connection.reader.lock().await.read_message().await?;
+ let id = envelope.id;
+ let payload =
+ M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
+ Ok(TypedEnvelope {
+ id,
+ connection_id,
+ payload,
+ })
}
pub fn request<T: RequestMessage>(
@@ -271,7 +300,6 @@ impl Peer {
.await
.get(&id)
.ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
- .0
.clone())
}
}
@@ -298,22 +326,22 @@ mod tests {
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_conn_id, f1) = client1
+ let client1_conn_id = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
- let (client2_conn_id, f2) = client2
+ let client2_conn_id = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
- let (_, f3) = server
+ let server_conn_id1 = server
.add_connection(listener.accept().await.unwrap().0)
.await;
- let (_, f4) = server
+ let server_conn_id2 = server
.add_connection(listener.accept().await.unwrap().0)
.await;
- smol::spawn(f1).detach();
- smol::spawn(f2).detach();
- smol::spawn(f3).detach();
- smol::spawn(f4).detach();
+ smol::spawn(client1.handle_messages(client1_conn_id)).detach();
+ smol::spawn(client2.handle_messages(client2_conn_id)).detach();
+ smol::spawn(server.handle_messages(server_conn_id1)).detach();
+ smol::spawn(server.handle_messages(server_conn_id2)).detach();
// define the expected requests and responses
let request1 = proto::OpenWorktree {
@@ -428,21 +456,21 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new();
- let (connection_id, handler) = client.add_connection(client_conn).await;
- smol::spawn(handler).detach();
+ let connection_id = client.add_connection(client_conn).await;
+ let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
+ barrier::channel();
+ let handle_messages = client.handle_messages(connection_id);
+ smol::spawn(async move {
+ handle_messages.await.unwrap();
+ incoming_messages_ended_tx.send(()).await.unwrap();
+ })
+ .detach();
client.disconnect(connection_id).await;
- // Try sending an empty payload over and over, until the client is dropped and hangs up.
- loop {
- match server_conn.write(&[]).await {
- Ok(_) => {}
- Err(err) => {
- if err.kind() == io::ErrorKind::BrokenPipe {
- break;
- }
- }
- }
- }
+ incoming_messages_ended_rx.recv().await;
+
+ let err = server_conn.write(&[]).await.unwrap_err();
+ assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
});
}
@@ -456,8 +484,8 @@ mod tests {
client_conn.close().await.unwrap();
let client = Peer::new();
- let (connection_id, handler) = client.add_connection(client_conn).await;
- smol::spawn(handler).detach();
+ let connection_id = client.add_connection(client_conn).await;
+ smol::spawn(client.handle_messages(connection_id)).detach();
let err = client
.request(