@@ -94,8 +94,7 @@ pub struct ConnectionState {
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
-const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
-const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
impl Peer {
pub fn new() -> Arc<Self> {
@@ -144,13 +143,17 @@ impl Peer {
});
loop {
- let read_message = reader.read_message().fuse();
+ let read_message = reader.read().fuse();
futures::pin_mut!(read_message);
+ let read_timeout = create_timer(2 * KEEPALIVE_INTERVAL).fuse();
+ futures::pin_mut!(read_timeout);
+
loop {
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
- if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
+ let outgoing = proto::Message::Envelope(outgoing);
+ if let Some(result) = writer.write(outgoing).timeout(2 * KEEPALIVE_INTERVAL).await {
result.context("failed to write RPC message")?;
} else {
Err(anyhow!("timed out writing message"))?;
@@ -159,19 +162,25 @@ impl Peer {
None => return Ok(()),
},
incoming = read_message => {
- let incoming = incoming.context("received invalid rpc message")?;
- if incoming_tx.send(incoming).await.is_err() {
- return Ok(());
+ let incoming = incoming.context("received invalid RPC message")?;
+ if let proto::Message::Envelope(incoming) = incoming {
+ if incoming_tx.send(incoming).await.is_err() {
+ return Ok(());
+ }
}
+
break;
},
_ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
- if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
+ if let Some(result) = writer.write(proto::Message::Ping).timeout(2 * KEEPALIVE_INTERVAL).await {
result.context("failed to send websocket ping")?;
} else {
Err(anyhow!("timed out sending websocket ping"))?;
}
}
+ _ = read_timeout => {
+ Err(anyhow!("timed out reading message"))?
+ }
}
}
}
@@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope};
use anyhow::Result;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{SinkExt as _, StreamExt as _};
-use prost::Message;
+use prost::Message as _;
use std::any::{Any, TypeId};
use std::{
io,
@@ -283,6 +283,12 @@ pub struct MessageStream<S> {
encoding_buffer: Vec<u8>,
}
+pub enum Message {
+ Envelope(Envelope),
+ Ping,
+ Pong,
+}
+
impl<S> MessageStream<S> {
pub fn new(stream: S) -> Self {
Self {
@@ -300,29 +306,37 @@ impl<S> MessageStream<S>
where
S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
{
- /// Write a given protobuf message to the stream.
- pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+ pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
#[cfg(any(test, feature = "test-support"))]
const COMPRESSION_LEVEL: i32 = -7;
#[cfg(not(any(test, feature = "test-support")))]
const COMPRESSION_LEVEL: i32 = 4;
- self.encoding_buffer.resize(message.encoded_len(), 0);
- self.encoding_buffer.clear();
- message
- .encode(&mut self.encoding_buffer)
- .map_err(|err| io::Error::from(err))?;
- let buffer =
- zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
- self.stream.send(WebSocketMessage::Binary(buffer)).await?;
- Ok(())
- }
+ match message {
+ Message::Envelope(message) => {
+ self.encoding_buffer.resize(message.encoded_len(), 0);
+ self.encoding_buffer.clear();
+ message
+ .encode(&mut self.encoding_buffer)
+ .map_err(|err| io::Error::from(err))?;
+ let buffer =
+ zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
+ .unwrap();
+ self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+ }
+ Message::Ping => {
+ self.stream
+ .send(WebSocketMessage::Ping(Default::default()))
+ .await?;
+ }
+ Message::Pong => {
+ self.stream
+ .send(WebSocketMessage::Pong(Default::default()))
+ .await?;
+ }
+ }
- pub async fn ping(&mut self) -> Result<(), WebSocketError> {
- self.stream
- .send(WebSocketMessage::Ping(Default::default()))
- .await?;
Ok(())
}
}
@@ -331,8 +345,7 @@ impl<S> MessageStream<S>
where
S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
{
- /// Read a protobuf message of the given type from the stream.
- pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
+ pub async fn read(&mut self) -> Result<Message, WebSocketError> {
while let Some(bytes) = self.stream.next().await {
match bytes? {
WebSocketMessage::Binary(bytes) => {
@@ -340,8 +353,10 @@ where
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
.map_err(io::Error::from)?;
- return Ok(envelope);
+ return Ok(Message::Envelope(envelope));
}
+ WebSocketMessage::Ping(_) => return Ok(Message::Ping),
+ WebSocketMessage::Pong(_) => return Ok(Message::Pong),
WebSocketMessage::Close(_) => break,
_ => {}
}