@@ -1,12 +1,9 @@
use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
-use futures::{
- future::{BoxFuture, Either},
- AsyncRead, AsyncWrite, FutureExt,
-};
+use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
use postage::{
- barrier, mpsc, oneshot,
+ mpsc,
prelude::{Sink, Stream},
};
use std::{
@@ -15,29 +12,18 @@ use std::{
fmt,
future::Future,
marker::PhantomData,
- pin::Pin,
sync::{
atomic::{self, AtomicU32},
Arc,
},
};
-type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
-type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
-
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
-struct Connection {
- writer: Mutex<MessageStream<BoxedWriter>>,
- reader: Mutex<MessageStream<BoxedReader>>,
- response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
- next_message_id: AtomicU32,
-}
-
type MessageHandler = Box<
dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
>;
@@ -74,18 +60,34 @@ impl<T: RequestMessage> TypedEnvelope<T> {
}
pub struct Peer {
- connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
- connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
+ connections: RwLock<HashMap<ConnectionId, Connection>>,
message_handlers: RwLock<Vec<MessageHandler>>,
handler_types: Mutex<HashSet<TypeId>>,
next_connection_id: AtomicU32,
}
+#[derive(Clone)]
+struct Connection {
+ outgoing_tx: mpsc::Sender<proto::Envelope>,
+ next_message_id: Arc<AtomicU32>,
+ response_channels: ResponseChannels,
+}
+
+pub struct ConnectionHandler<Conn> {
+ peer: Arc<Peer>,
+ connection_id: ConnectionId,
+ response_channels: ResponseChannels,
+ outgoing_rx: mpsc::Receiver<proto::Envelope>,
+ reader: MessageStream<Conn>,
+ writer: MessageStream<Conn>,
+}
+
+type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
+
impl Peer {
pub fn new() -> Arc<Self> {
Arc::new(Self {
connections: Default::default(),
- connection_close_barriers: Default::default(),
message_handlers: Default::default(),
handler_types: Default::default(),
next_connection_id: Default::default(),
@@ -127,7 +129,10 @@ impl Peer {
rx
}
- pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
+ pub async fn add_connection<Conn>(
+ self: &Arc<Self>,
+ conn: Conn,
+ ) -> (ConnectionId, ConnectionHandler<Conn>)
where
Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
@@ -135,120 +140,37 @@ impl Peer {
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
);
- self.connections.write().await.insert(
+ let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
+ let connection = Connection {
+ outgoing_tx,
+ next_message_id: Default::default(),
+ response_channels: Default::default(),
+ };
+ let handler = ConnectionHandler {
+ peer: self.clone(),
connection_id,
- Arc::new(Connection {
- reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
- writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
- response_channels: Default::default(),
- next_message_id: Default::default(),
- }),
- );
- connection_id
+ response_channels: connection.response_channels.clone(),
+ outgoing_rx,
+ reader: MessageStream::new(conn.clone()),
+ writer: MessageStream::new(conn),
+ };
+ self.connections
+ .write()
+ .await
+ .insert(connection_id, connection);
+ (connection_id, handler)
}
pub async fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().await.remove(&connection_id);
- self.connection_close_barriers
- .write()
- .await
- .remove(&connection_id);
}
pub async fn reset(&self) {
self.connections.write().await.clear();
- self.connection_close_barriers.write().await.clear();
self.handler_types.lock().await.clear();
self.message_handlers.write().await.clear();
}
- pub fn handle_messages(
- self: &Arc<Self>,
- connection_id: ConnectionId,
- ) -> impl Future<Output = Result<()>> + 'static {
- let (close_tx, mut close_rx) = barrier::channel();
- let this = self.clone();
- async move {
- this.connection_close_barriers
- .write()
- .await
- .insert(connection_id, close_tx);
- let connection = this.connection(connection_id).await?;
- let closed = close_rx.recv();
- futures::pin_mut!(closed);
-
- loop {
- let mut reader = connection.reader.lock().await;
- let read_message = reader.read_message();
- futures::pin_mut!(read_message);
-
- match futures::future::select(read_message, &mut closed).await {
- Either::Left((Ok(incoming), _)) => {
- if let Some(responding_to) = incoming.responding_to {
- let channel = connection
- .response_channels
- .lock()
- .await
- .remove(&responding_to);
- if let Some(mut tx) = channel {
- tx.send(incoming).await.ok();
- } else {
- log::warn!(
- "received RPC response to unknown request {}",
- responding_to
- );
- }
- } else {
- let mut envelope = Some(incoming);
- let mut handler_index = None;
- let mut handler_was_dropped = false;
- for (i, handler) in
- this.message_handlers.read().await.iter().enumerate()
- {
- if let Some(future) = handler(&mut envelope, connection_id) {
- handler_was_dropped = future.await;
- handler_index = Some(i);
- break;
- }
- }
-
- if let Some(handler_index) = handler_index {
- if handler_was_dropped {
- drop(this.message_handlers.write().await.remove(handler_index));
- }
- } else {
- log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
- }
- }
- }
- Either::Left((Err(error), _)) => {
- log::warn!("received invalid RPC message: {}", error);
- Err(error)?;
- }
- Either::Right(_) => return Ok(()),
- }
- }
- }
- }
-
- pub async fn receive<M: EnvelopedMessage>(
- self: &Arc<Self>,
- connection_id: ConnectionId,
- ) -> Result<TypedEnvelope<M>> {
- let connection = self.connection(connection_id).await?;
- let envelope = connection.reader.lock().await.read_message().await?;
- let original_sender_id = envelope.original_sender_id;
- let message_id = envelope.id;
- let payload =
- M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
- Ok(TypedEnvelope {
- sender_id: connection_id,
- original_sender_id: original_sender_id.map(PeerId),
- message_id,
- payload,
- })
- }
-
pub fn request<T: RequestMessage>(
self: &Arc<Self>,
receiver_id: ConnectionId,
@@ -273,9 +195,9 @@ impl Peer {
request: T,
) -> impl Future<Output = Result<T::Response>> {
let this = self.clone();
- let (tx, mut rx) = oneshot::channel();
+ let (tx, mut rx) = mpsc::channel(1);
async move {
- let connection = this.connection(receiver_id).await?;
+ let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
@@ -285,19 +207,13 @@ impl Peer {
.await
.insert(message_id, tx);
connection
- .writer
- .lock()
- .await
- .write_message(&request.into_envelope(
- message_id,
- None,
- original_sender_id.map(|id| id.0),
- ))
+ .outgoing_tx
+ .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
.await?;
let response = rx
.recv()
.await
- .expect("response channel was unexpectedly dropped");
+ .ok_or_else(|| anyhow!("connection was closed"))?;
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))
}
@@ -305,20 +221,18 @@ impl Peer {
pub fn send<T: EnvelopedMessage>(
self: &Arc<Self>,
- connection_id: ConnectionId,
+ receiver_id: ConnectionId,
message: T,
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
- let connection = this.connection(connection_id).await?;
+ let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
- .writer
- .lock()
- .await
- .write_message(&message.into_envelope(message_id, None, None))
+ .outgoing_tx
+ .send(message.into_envelope(message_id, None, None))
.await?;
Ok(())
}
@@ -332,15 +246,13 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
- let connection = this.connection(receiver_id).await?;
+ let mut connection = this.connection(receiver_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
- .writer
- .lock()
- .await
- .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
+ .outgoing_tx
+ .send(message.into_envelope(message_id, None, Some(sender_id.0)))
.await?;
Ok(())
}
@@ -353,28 +265,114 @@ impl Peer {
) -> impl Future<Output = Result<()>> {
let this = self.clone();
async move {
- let connection = this.connection(receipt.sender_id).await?;
+ let mut connection = this.connection(receipt.sender_id).await?;
let message_id = connection
.next_message_id
.fetch_add(1, atomic::Ordering::SeqCst);
connection
- .writer
- .lock()
- .await
- .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
+ .outgoing_tx
+ .send(response.into_envelope(message_id, Some(receipt.message_id), None))
.await?;
Ok(())
}
}
- async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
- Ok(self
- .connections
- .read()
- .await
- .get(&id)
- .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
- .clone())
+ fn connection(
+ self: &Arc<Self>,
+ connection_id: ConnectionId,
+ ) -> impl Future<Output = Result<Connection>> {
+ let this = self.clone();
+ async move {
+ let connections = this.connections.read().await;
+ let connection = connections
+ .get(&connection_id)
+ .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
+ Ok(connection.clone())
+ }
+ }
+}
+
+impl<Conn> ConnectionHandler<Conn>
+where
+ Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
+{
+ pub async fn run(mut self) -> Result<()> {
+ loop {
+ let read_message = self.reader.read_message().fuse();
+ futures::pin_mut!(read_message);
+ loop {
+ futures::select! {
+ incoming = read_message => match incoming {
+ Ok(incoming) => {
+ Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
+ break;
+ }
+ Err(error) => {
+ self.response_channels.lock().await.clear();
+ Err(error).context("received invalid RPC message")?;
+ }
+ },
+ outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
+ Some(outgoing) => {
+ if let Err(result) = self.writer.write_message(&outgoing).await {
+ self.response_channels.lock().await.clear();
+ Err(result).context("failed to write RPC message")?;
+ }
+ }
+ None => return Ok(()),
+ }
+ }
+ }
+ }
+ }
+
+ pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
+ let envelope = self.reader.read_message().await?;
+ let original_sender_id = envelope.original_sender_id;
+ let message_id = envelope.id;
+ let payload =
+ M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
+ Ok(TypedEnvelope {
+ sender_id: self.connection_id,
+ original_sender_id: original_sender_id.map(PeerId),
+ message_id,
+ payload,
+ })
+ }
+
+ async fn handle_incoming_message(
+ message: proto::Envelope,
+ peer: &Arc<Peer>,
+ connection_id: ConnectionId,
+ response_channels: &ResponseChannels,
+ ) {
+ if let Some(responding_to) = message.responding_to {
+ let channel = response_channels.lock().await.remove(&responding_to);
+ if let Some(mut tx) = channel {
+ tx.send(message).await.ok();
+ } else {
+ log::warn!("received RPC response to unknown request {}", responding_to);
+ }
+ } else {
+ let mut envelope = Some(message);
+ let mut handler_index = None;
+ let mut handler_was_dropped = false;
+ for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
+ if let Some(future) = handler(&mut envelope, connection_id) {
+ handler_was_dropped = future.await;
+ handler_index = Some(i);
+ break;
+ }
+ }
+
+ if let Some(handler_index) = handler_index {
+ if handler_was_dropped {
+ drop(peer.message_handlers.write().await.remove(handler_index));
+ }
+ } else {
+ log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
+ }
+ }
}
}
@@ -412,22 +410,22 @@ mod tests {
let server = Peer::new();
let client1 = Peer::new();
let client2 = Peer::new();
- let client1_conn_id = client1
+ let (client1_conn_id, task1) = client1
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
- let client2_conn_id = client2
+ let (client2_conn_id, task2) = client2
.add_connection(UnixStream::connect(&socket_path).await.unwrap())
.await;
- let server_conn_id1 = server
+ let (_, task3) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
- let server_conn_id2 = server
+ let (_, task4) = server
.add_connection(listener.accept().await.unwrap().0)
.await;
- smol::spawn(client1.handle_messages(client1_conn_id)).detach();
- smol::spawn(client2.handle_messages(client2_conn_id)).detach();
- smol::spawn(server.handle_messages(server_conn_id1)).detach();
- smol::spawn(server.handle_messages(server_conn_id2)).detach();
+ smol::spawn(task1.run()).detach();
+ smol::spawn(task2.run()).detach();
+ smol::spawn(task3.run()).detach();
+ smol::spawn(task4.run()).detach();
// define the expected requests and responses
let request1 = proto::Auth {
@@ -548,12 +546,11 @@ mod tests {
let (mut server_conn, _) = listener.accept().await.unwrap();
let client = Peer::new();
- let connection_id = client.add_connection(client_conn).await;
+ let (connection_id, handler) = client.add_connection(client_conn).await;
let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
- barrier::channel();
- let handle_messages = client.handle_messages(connection_id);
+ postage::barrier::channel();
smol::spawn(async move {
- handle_messages.await.ok();
+ handler.run().await.ok();
incoming_messages_ended_tx.send(()).await.unwrap();
})
.detach();
@@ -576,8 +573,8 @@ mod tests {
client_conn.close().await.unwrap();
let client = Peer::new();
- let connection_id = client.add_connection(client_conn).await;
- smol::spawn(client.handle_messages(connection_id)).detach();
+ let (connection_id, handler) = client.add_connection(client_conn).await;
+ smol::spawn(handler.run()).detach();
let err = client
.request(
@@ -589,10 +586,7 @@ mod tests {
)
.await
.unwrap_err();
- assert_eq!(
- err.downcast_ref::<io::Error>().unwrap().kind(),
- io::ErrorKind::BrokenPipe
- );
+ assert_eq!(err.to_string(), "connection was closed");
});
}
}