message_stream.rs

  1#![allow(non_snake_case)]
  2
  3pub use ::proto::*;
  4
  5use async_tungstenite::tungstenite::Message as WebSocketMessage;
  6use futures::{SinkExt as _, StreamExt as _};
  7use proto::Message as _;
  8use std::time::Instant;
  9use std::{fmt::Debug, io};
 10use zstd::zstd_safe::WriteBuf;
 11
 12const KIB: usize = 1024;
 13const MIB: usize = KIB * 1024;
 14const MAX_BUFFER_LEN: usize = MIB;
 15
 16/// A stream of protobuf messages.
 17pub struct MessageStream<S> {
 18    stream: S,
 19    encoding_buffer: Vec<u8>,
 20}
 21
 22#[derive(Debug)]
 23pub enum Message {
 24    Envelope(Envelope),
 25    Ping,
 26    Pong,
 27}
 28
 29impl<S> MessageStream<S> {
 30    pub const fn new(stream: S) -> Self {
 31        Self {
 32            stream,
 33            encoding_buffer: Vec::new(),
 34        }
 35    }
 36}
 37
 38impl<S> MessageStream<S>
 39where
 40    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
 41{
 42    pub async fn write(&mut self, message: Message) -> anyhow::Result<()> {
 43        #[cfg(any(test, feature = "test-support"))]
 44        const COMPRESSION_LEVEL: i32 = -7;
 45
 46        #[cfg(not(any(test, feature = "test-support")))]
 47        const COMPRESSION_LEVEL: i32 = 4;
 48
 49        match message {
 50            Message::Envelope(message) => {
 51                self.encoding_buffer.reserve(message.encoded_len());
 52                message
 53                    .encode(&mut self.encoding_buffer)
 54                    .map_err(io::Error::from)?;
 55                let buffer =
 56                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
 57                        .unwrap();
 58
 59                self.encoding_buffer.clear();
 60                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 61                self.stream
 62                    .send(WebSocketMessage::Binary(buffer.into()))
 63                    .await?;
 64            }
 65            Message::Ping => {
 66                self.stream
 67                    .send(WebSocketMessage::Ping(Default::default()))
 68                    .await?;
 69            }
 70            Message::Pong => {
 71                self.stream
 72                    .send(WebSocketMessage::Pong(Default::default()))
 73                    .await?;
 74            }
 75        }
 76
 77        Ok(())
 78    }
 79}
 80
 81impl<S> MessageStream<S>
 82where
 83    S: futures::Stream<Item = anyhow::Result<WebSocketMessage>> + Unpin,
 84{
 85    pub async fn read(&mut self) -> anyhow::Result<(Message, Instant)> {
 86        while let Some(bytes) = self.stream.next().await {
 87            let received_at = Instant::now();
 88            match bytes? {
 89                WebSocketMessage::Binary(bytes) => {
 90                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
 91                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
 92                        .map_err(io::Error::from)?;
 93
 94                    self.encoding_buffer.clear();
 95                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 96                    return Ok((Message::Envelope(envelope), received_at));
 97                }
 98                WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
 99                WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
100                WebSocketMessage::Close(_) => break,
101                _ => {}
102            }
103        }
104        anyhow::bail!("connection closed");
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose root name is `"abcdefg"` repeated `times`
    /// times, so the test controls how large the encoded message is.
    fn worktree_envelope(times: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(times),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write a small message and then a very large one. After each write, the scratch
        // buffer must have been trimmed back under the cap.
        let mut writer = MessageStream::new(tx.sink_map_err(|_| anyhow::anyhow!("")));
        for times in [10, 1_000_000] {
            writer.write(worktree_envelope(times)).await.unwrap();
            assert!(writer.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both messages back. The same cap must hold on the receiving side.
        let mut reader = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            reader.read().await.unwrap();
            assert!(reader.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}