message_stream.rs

  1#![allow(non_snake_case)]
  2
  3pub use ::proto::*;
  4
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use futures::{SinkExt as _, StreamExt as _};
  7use proto::Message as _;
  8use std::time::Instant;
  9use std::{fmt::Debug, io};
 10
 11const KIB: usize = 1024;
 12const MIB: usize = KIB * 1024;
 13const MAX_BUFFER_LEN: usize = MIB;
 14
 15/// A stream of protobuf messages.
 16pub struct MessageStream<S> {
 17    stream: S,
 18    encoding_buffer: Vec<u8>,
 19}
 20
 21#[derive(Debug)]
 22pub enum Message {
 23    Envelope(Envelope),
 24    Ping,
 25    Pong,
 26}
 27
 28impl<S> MessageStream<S> {
 29    pub fn new(stream: S) -> Self {
 30        Self {
 31            stream,
 32            encoding_buffer: Vec::new(),
 33        }
 34    }
 35}
 36
 37impl<S> MessageStream<S>
 38where
 39    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
 40{
 41    pub async fn write(&mut self, message: Message) -> anyhow::Result<()> {
 42        #[cfg(any(test, feature = "test-support"))]
 43        const COMPRESSION_LEVEL: i32 = -7;
 44
 45        #[cfg(not(any(test, feature = "test-support")))]
 46        const COMPRESSION_LEVEL: i32 = 4;
 47
 48        match message {
 49            Message::Envelope(message) => {
 50                self.encoding_buffer.reserve(message.encoded_len());
 51                message
 52                    .encode(&mut self.encoding_buffer)
 53                    .map_err(io::Error::from)?;
 54                let buffer =
 55                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
 56                        .unwrap();
 57
 58                self.encoding_buffer.clear();
 59                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 60                self.stream
 61                    .send(WebSocketMessage::Binary(buffer.into()))
 62                    .await?;
 63            }
 64            Message::Ping => {
 65                self.stream
 66                    .send(WebSocketMessage::Ping(Default::default()))
 67                    .await?;
 68            }
 69            Message::Pong => {
 70                self.stream
 71                    .send(WebSocketMessage::Pong(Default::default()))
 72                    .await?;
 73            }
 74        }
 75
 76        Ok(())
 77    }
 78}
 79
 80impl<S> MessageStream<S>
 81where
 82    S: futures::Stream<Item = anyhow::Result<WebSocketMessage>> + Unpin,
 83{
 84    pub async fn read(&mut self) -> anyhow::Result<(Message, Instant)> {
 85        while let Some(bytes) = self.stream.next().await {
 86            let received_at = Instant::now();
 87            match bytes? {
 88                WebSocketMessage::Binary(bytes) => {
 89                    zstd::stream::copy_decode(
 90                        zstd::zstd_safe::WriteBuf::as_slice(&*bytes),
 91                        &mut self.encoding_buffer,
 92                    )?;
 93                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
 94                        .map_err(io::Error::from)?;
 95
 96                    self.encoding_buffer.clear();
 97                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 98                    return Ok((Message::Envelope(envelope), received_at));
 99                }
100                WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
101                WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
102                WebSocketMessage::Close(_) => break,
103                _ => {}
104            }
105        }
106        anyhow::bail!("connection closed");
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope message whose `UpdateWorktree` payload carries
    /// `root_name`, letting each test control the encoded size.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow::anyhow!("")));

        // A small message, then one far larger than MAX_BUFFER_LEN; the scratch
        // buffer must be shrunk back after each write.
        for repetitions in [10, 1000000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Reading both messages back must leave the buffer bounded as well.
        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}