message_stream.rs

  1#![allow(non_snake_case)]
  2
  3pub use ::proto::*;
  4
  5use anyhow::anyhow;
  6use async_tungstenite::tungstenite::Message as WebSocketMessage;
  7use futures::{SinkExt as _, StreamExt as _};
  8use proto::Message as _;
  9use std::time::Instant;
 10use std::{fmt::Debug, io};
 11use zstd::zstd_safe::WriteBuf;
 12
 13const KIB: usize = 1024;
 14const MIB: usize = KIB * 1024;
 15const MAX_BUFFER_LEN: usize = MIB;
 16
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S`. When `S` is a sink of WebSocket messages, `write` is
/// available. When `S` is a stream of `Result<WebSocketMessage, anyhow::Error>`, `read` is
/// available.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport (sink and/or stream).
    stream: S,
    /// Scratch space reused for protobuf encoding (on write) and zstd decompression (on read).
    /// It is cleared and shrunk back to `MAX_BUFFER_LEN` once an envelope has been processed.
    encoding_buffer: Vec<u8>,
}
 22
/// A single protocol-level message carried by a `MessageStream`.
///
/// - `Envelope`: a protobuf `Envelope`, transmitted zstd-compressed in a binary frame.
/// - `Ping` / `Pong`: WebSocket ping/pong frames, written with empty (default) payloads.
// `Envelope` is presumably much larger than the unit variants. Boxing it would change the public
// variant type, so the size lint is allowed instead. NOTE(review): confirm this trade-off.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    Envelope(Envelope),
    Ping,
    Pong,
}
 30
 31impl<S> MessageStream<S> {
 32    pub fn new(stream: S) -> Self {
 33        Self {
 34            stream,
 35            encoding_buffer: Vec::new(),
 36        }
 37    }
 38}
 39
 40impl<S> MessageStream<S>
 41where
 42    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
 43{
 44    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
 45        #[cfg(any(test, feature = "test-support"))]
 46        const COMPRESSION_LEVEL: i32 = -7;
 47
 48        #[cfg(not(any(test, feature = "test-support")))]
 49        const COMPRESSION_LEVEL: i32 = 4;
 50
 51        match message {
 52            Message::Envelope(message) => {
 53                self.encoding_buffer.reserve(message.encoded_len());
 54                message
 55                    .encode(&mut self.encoding_buffer)
 56                    .map_err(io::Error::from)?;
 57                let buffer =
 58                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
 59                        .unwrap();
 60
 61                self.encoding_buffer.clear();
 62                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 63                self.stream
 64                    .send(WebSocketMessage::Binary(buffer.into()))
 65                    .await?;
 66            }
 67            Message::Ping => {
 68                self.stream
 69                    .send(WebSocketMessage::Ping(Default::default()))
 70                    .await?;
 71            }
 72            Message::Pong => {
 73                self.stream
 74                    .send(WebSocketMessage::Pong(Default::default()))
 75                    .await?;
 76            }
 77        }
 78
 79        Ok(())
 80    }
 81}
 82
 83impl<S> MessageStream<S>
 84where
 85    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
 86{
 87    pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
 88        while let Some(bytes) = self.stream.next().await {
 89            let received_at = Instant::now();
 90            match bytes? {
 91                WebSocketMessage::Binary(bytes) => {
 92                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
 93                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
 94                        .map_err(io::Error::from)?;
 95
 96                    self.encoding_buffer.clear();
 97                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 98                    return Ok((Message::Envelope(envelope), received_at));
 99                }
100                WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
101                WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
102                WebSocketMessage::Close(_) => break,
103                _ => {}
104            }
105        }
106        Err(anyhow!("connection closed"))
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying only an `UpdateWorktree` payload with the given root name,
    /// which lets the test control the encoded message size.
    fn envelope_with_root_name(root_name: String) -> Envelope {
        let update = UpdateWorktree {
            root_name,
            ..Default::default()
        };
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(update)),
            ..Default::default()
        }
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write one small and one very large envelope. After each write the scratch buffer
        // must have been trimmed back under the cap.
        let mut writer = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1000000] {
            let envelope = envelope_with_root_name("abcdefg".repeat(repetitions));
            writer.write(Message::Envelope(envelope)).await.unwrap();
            assert!(writer.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both envelopes back and check the same bound on the reading side.
        let mut reader = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            reader.read().await.unwrap();
            assert!(reader.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}