diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index e9b8d50e6808d1bc23b6481b581e9750286a5c97..d5f790833e5d01b48233ae6925bda4d3a980b1c4 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -94,8 +94,7 @@ pub struct ConnectionState { Arc>>>>, } -const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2); -const WRITE_TIMEOUT: Duration = Duration::from_secs(10); +const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1); impl Peer { pub fn new() -> Arc { @@ -144,13 +143,17 @@ impl Peer { }); loop { - let read_message = reader.read_message().fuse(); + let read_message = reader.read().fuse(); futures::pin_mut!(read_message); + let read_timeout = create_timer(2 * KEEPALIVE_INTERVAL).fuse(); + futures::pin_mut!(read_timeout); + loop { futures::select_biased! { outgoing = outgoing_rx.next().fuse() => match outgoing { Some(outgoing) => { - if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await { + let outgoing = proto::Message::Envelope(outgoing); + if let Some(result) = writer.write(outgoing).timeout(2 * KEEPALIVE_INTERVAL).await { result.context("failed to write RPC message")?; } else { Err(anyhow!("timed out writing message"))?; @@ -159,19 +162,25 @@ impl Peer { None => return Ok(()), }, incoming = read_message => { - let incoming = incoming.context("received invalid rpc message")?; - if incoming_tx.send(incoming).await.is_err() { - return Ok(()); + let incoming = incoming.context("received invalid RPC message")?; + if let proto::Message::Envelope(incoming) = incoming { + if incoming_tx.send(incoming).await.is_err() { + return Ok(()); + } } + break; }, _ = create_timer(KEEPALIVE_INTERVAL).fuse() => { - if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await { + if let Some(result) = writer.write(proto::Message::Ping).timeout(2 * KEEPALIVE_INTERVAL).await { result.context("failed to send websocket ping")?; } else { Err(anyhow!("timed out sending websocket ping"))?; } } + _ = read_timeout => { + Err(anyhow!("timed out reading message"))? + } } } } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 3d7557842a3b50cca21acaff3304cc93a8ec43e1..a1cb3dbc2eecc8d70bb25e6992ebc9bcb6d96560 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope}; use anyhow::Result; use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage}; use futures::{SinkExt as _, StreamExt as _}; -use prost::Message; +use prost::Message as _; use std::any::{Any, TypeId}; use std::{ io, @@ -283,6 +283,12 @@ pub struct MessageStream { encoding_buffer: Vec, } +pub enum Message { + Envelope(Envelope), + Ping, + Pong, +} + impl MessageStream { pub fn new(stream: S) -> Self { Self { @@ -300,29 +306,37 @@ impl MessageStream where S: futures::Sink + Unpin, { - /// Write a given protobuf message to the stream. - pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { + pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> { #[cfg(any(test, feature = "test-support"))] const COMPRESSION_LEVEL: i32 = -7; #[cfg(not(any(test, feature = "test-support")))] const COMPRESSION_LEVEL: i32 = 4; - self.encoding_buffer.resize(message.encoded_len(), 0); - self.encoding_buffer.clear(); - message - .encode(&mut self.encoding_buffer) - .map_err(|err| io::Error::from(err))?; - let buffer = - zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap(); - self.stream.send(WebSocketMessage::Binary(buffer)).await?; - Ok(()) - } + match message { + Message::Envelope(message) => { + self.encoding_buffer.resize(message.encoded_len(), 0); + self.encoding_buffer.clear(); + message + .encode(&mut self.encoding_buffer) + .map_err(|err| io::Error::from(err))?; + let buffer = + zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL) + .unwrap(); + self.stream.send(WebSocketMessage::Binary(buffer)).await?; + } + Message::Ping => { + self.stream + .send(WebSocketMessage::Ping(Default::default())) + .await?; + } + Message::Pong => { + self.stream + .send(WebSocketMessage::Pong(Default::default())) + .await?; + } + } - pub async fn ping(&mut self) -> Result<(), WebSocketError> { - self.stream - .send(WebSocketMessage::Ping(Default::default())) - .await?; Ok(()) } } @@ -331,8 +345,7 @@ impl MessageStream where S: futures::Stream> + Unpin, { - /// Read a protobuf message of the given type from the stream. - pub async fn read_message(&mut self) -> Result { + pub async fn read(&mut self) -> Result { while let Some(bytes) = self.stream.next().await { match bytes? { WebSocketMessage::Binary(bytes) => { @@ -340,8 +353,10 @@ where zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap(); let envelope = Envelope::decode(self.encoding_buffer.as_slice()) .map_err(io::Error::from)?; - return Ok(envelope); + return Ok(Message::Envelope(envelope)); } + WebSocketMessage::Ping(_) => return Ok(Message::Ping), + WebSocketMessage::Pong(_) => return Ok(Message::Pong), WebSocketMessage::Close(_) => break, _ => {} }