WIP: Introduce a read timeout

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/rpc/src/peer.rs  | 25 +++++++++++++------
crates/rpc/src/proto.rs | 55 +++++++++++++++++++++++++++---------------
2 files changed, 52 insertions(+), 28 deletions(-)

Detailed changes

crates/rpc/src/peer.rs 🔗

@@ -94,8 +94,7 @@ pub struct ConnectionState {
         Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 }
 
-const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
-const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
 
 impl Peer {
     pub fn new() -> Arc<Self> {
@@ -144,13 +143,17 @@ impl Peer {
             });
 
             loop {
-                let read_message = reader.read_message().fuse();
+                let read_message = reader.read().fuse();
                 futures::pin_mut!(read_message);
+                let read_timeout = create_timer(2 * KEEPALIVE_INTERVAL).fuse();
+                futures::pin_mut!(read_timeout);
+
                 loop {
                     futures::select_biased! {
                         outgoing = outgoing_rx.next().fuse() => match outgoing {
                             Some(outgoing) => {
-                                if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
+                                let outgoing = proto::Message::Envelope(outgoing);
+                                if let Some(result) = writer.write(outgoing).timeout(2 * KEEPALIVE_INTERVAL).await {
                                     result.context("failed to write RPC message")?;
                                 } else {
                                     Err(anyhow!("timed out writing message"))?;
@@ -159,19 +162,25 @@ impl Peer {
                             None => return Ok(()),
                         },
                         incoming = read_message => {
-                            let incoming = incoming.context("received invalid rpc message")?;
-                            if incoming_tx.send(incoming).await.is_err() {
-                                return Ok(());
+                            let incoming = incoming.context("received invalid RPC message")?;
+                            if let proto::Message::Envelope(incoming) = incoming {
+                                if incoming_tx.send(incoming).await.is_err() {
+                                    return Ok(());
+                                }
                             }
+
                             break;
                         },
                         _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
-                            if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
+                            if let Some(result) = writer.write(proto::Message::Ping).timeout(2 * KEEPALIVE_INTERVAL).await {
                                 result.context("failed to send websocket ping")?;
                             } else {
                                 Err(anyhow!("timed out sending websocket ping"))?;
                             }
                         }
+                        _ = read_timeout => {
+                            Err(anyhow!("timed out reading message"))?
+                        }
                     }
                 }
             }

crates/rpc/src/proto.rs 🔗

@@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope};
 use anyhow::Result;
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use futures::{SinkExt as _, StreamExt as _};
-use prost::Message;
+use prost::Message as _;
 use std::any::{Any, TypeId};
 use std::{
     io,
@@ -283,6 +283,12 @@ pub struct MessageStream<S> {
     encoding_buffer: Vec<u8>,
 }
 
+pub enum Message {
+    Envelope(Envelope),
+    Ping,
+    Pong,
+}
+
 impl<S> MessageStream<S> {
     pub fn new(stream: S) -> Self {
         Self {
@@ -300,29 +306,37 @@ impl<S> MessageStream<S>
 where
     S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
 {
-    /// Write a given protobuf message to the stream.
-    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+    pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
         #[cfg(any(test, feature = "test-support"))]
         const COMPRESSION_LEVEL: i32 = -7;
 
         #[cfg(not(any(test, feature = "test-support")))]
         const COMPRESSION_LEVEL: i32 = 4;
 
-        self.encoding_buffer.resize(message.encoded_len(), 0);
-        self.encoding_buffer.clear();
-        message
-            .encode(&mut self.encoding_buffer)
-            .map_err(|err| io::Error::from(err))?;
-        let buffer =
-            zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
-        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
-        Ok(())
-    }
+        match message {
+            Message::Envelope(message) => {
+                self.encoding_buffer.resize(message.encoded_len(), 0);
+                self.encoding_buffer.clear();
+                message
+                    .encode(&mut self.encoding_buffer)
+                    .map_err(|err| io::Error::from(err))?;
+                let buffer =
+                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
+                        .unwrap();
+                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+            }
+            Message::Ping => {
+                self.stream
+                    .send(WebSocketMessage::Ping(Default::default()))
+                    .await?;
+            }
+            Message::Pong => {
+                self.stream
+                    .send(WebSocketMessage::Pong(Default::default()))
+                    .await?;
+            }
+        }
 
-    pub async fn ping(&mut self) -> Result<(), WebSocketError> {
-        self.stream
-            .send(WebSocketMessage::Ping(Default::default()))
-            .await?;
         Ok(())
     }
 }
@@ -331,8 +345,7 @@ impl<S> MessageStream<S>
 where
     S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
 {
-    /// Read a protobuf message of the given type from the stream.
-    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
+    pub async fn read(&mut self) -> Result<Message, WebSocketError> {
         while let Some(bytes) = self.stream.next().await {
             match bytes? {
                 WebSocketMessage::Binary(bytes) => {
@@ -340,8 +353,10 @@ where
                     zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
                     let envelope = Envelope::decode(self.encoding_buffer.as_slice())
                         .map_err(io::Error::from)?;
-                    return Ok(envelope);
+                    return Ok(Message::Envelope(envelope));
                 }
+                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
+                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
                 WebSocketMessage::Close(_) => break,
                 _ => {}
             }