proto.rs

  1#![allow(non_snake_case)]
  2
  3use anyhow::anyhow;
  4use async_tungstenite::tungstenite::Message as WebSocketMessage;
  5use futures::{SinkExt as _, StreamExt as _};
  6pub use proto::{Message as _, *};
  7use std::time::Instant;
  8use std::{fmt::Debug, io};
  9
 10const KIB: usize = 1024;
 11const MIB: usize = KIB * 1024;
 12const MAX_BUFFER_LEN: usize = MIB;
 13
 14/// A stream of protobuf messages.
 15pub struct MessageStream<S> {
 16    stream: S,
 17    encoding_buffer: Vec<u8>,
 18}
 19
 20#[allow(clippy::large_enum_variant)]
 21#[derive(Debug)]
 22pub enum Message {
 23    Envelope(Envelope),
 24    Ping,
 25    Pong,
 26}
 27
 28impl<S> MessageStream<S> {
 29    pub fn new(stream: S) -> Self {
 30        Self {
 31            stream,
 32            encoding_buffer: Vec::new(),
 33        }
 34    }
 35
 36    pub fn inner_mut(&mut self) -> &mut S {
 37        &mut self.stream
 38    }
 39}
 40
 41impl<S> MessageStream<S>
 42where
 43    S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
 44{
 45    pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
 46        #[cfg(any(test, feature = "test-support"))]
 47        const COMPRESSION_LEVEL: i32 = -7;
 48
 49        #[cfg(not(any(test, feature = "test-support")))]
 50        const COMPRESSION_LEVEL: i32 = 4;
 51
 52        match message {
 53            Message::Envelope(message) => {
 54                self.encoding_buffer.reserve(message.encoded_len());
 55                message
 56                    .encode(&mut self.encoding_buffer)
 57                    .map_err(io::Error::from)?;
 58                let buffer =
 59                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
 60                        .unwrap();
 61
 62                self.encoding_buffer.clear();
 63                self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 64                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
 65            }
 66            Message::Ping => {
 67                self.stream
 68                    .send(WebSocketMessage::Ping(Default::default()))
 69                    .await?;
 70            }
 71            Message::Pong => {
 72                self.stream
 73                    .send(WebSocketMessage::Pong(Default::default()))
 74                    .await?;
 75            }
 76        }
 77
 78        Ok(())
 79    }
 80}
 81
 82impl<S> MessageStream<S>
 83where
 84    S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
 85{
 86    pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
 87        while let Some(bytes) = self.stream.next().await {
 88            let received_at = Instant::now();
 89            match bytes? {
 90                WebSocketMessage::Binary(bytes) => {
 91                    zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
 92                    let envelope = Envelope::decode(self.encoding_buffer.as_slice())
 93                        .map_err(io::Error::from)?;
 94
 95                    self.encoding_buffer.clear();
 96                    self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
 97                    return Ok((Message::Envelope(envelope), received_at));
 98                }
 99                WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
100                WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
101                WebSocketMessage::Close(_) => break,
102                _ => {}
103            }
104        }
105        Err(anyhow!("connection closed"))
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose payload is an `UpdateWorktree` with the given root name.
    fn worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// Writing and then reading a small and a very large envelope must leave the scratch
    /// buffer's capacity at or below `MAX_BUFFER_LEN` after every operation.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for &repetitions in &[10usize, 1_000_000] {
            sink.write(Message::Envelope(worktree_envelope(
                "abcdefg".repeat(repetitions),
            )))
            .await
            .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}