1#![allow(non_snake_case)]
2
3pub use ::proto::*;
4
5use async_tungstenite::tungstenite::Message as WebSocketMessage;
6use futures::{SinkExt as _, StreamExt as _};
7use proto::Message as _;
8use std::time::Instant;
9use std::{fmt::Debug, io};
10use zstd::zstd_safe::WriteBuf;
11
const KIB: usize = 1 << 10;
const MIB: usize = KIB << 10;
/// Largest capacity the shared scratch buffer is allowed to retain once a
/// message has been processed (it is shrunk back to this after each use).
const MAX_BUFFER_LEN: usize = MIB;
15
16/// A stream of protobuf messages.
/// A stream of protobuf messages carried over a WebSocket-message transport.
pub struct MessageStream<S> {
    /// Scratch space shared by `write` (protobuf encoding) and `read`
    /// (zstd decompression), cleared after every message so its capacity
    /// can be reused.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport (a sink, a stream, or both).
    stream: S,
}
21
22#[derive(Debug)]
23pub enum Message {
24 Envelope(Envelope),
25 Ping,
26 Pong,
27}
28
29impl<S> MessageStream<S> {
30 pub fn new(stream: S) -> Self {
31 Self {
32 stream,
33 encoding_buffer: Vec::new(),
34 }
35 }
36}
37
38impl<S> MessageStream<S>
39where
40 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
41{
42 pub async fn write(&mut self, message: Message) -> anyhow::Result<()> {
43 #[cfg(any(test, feature = "test-support"))]
44 const COMPRESSION_LEVEL: i32 = -7;
45
46 #[cfg(not(any(test, feature = "test-support")))]
47 const COMPRESSION_LEVEL: i32 = 4;
48
49 match message {
50 Message::Envelope(message) => {
51 self.encoding_buffer.reserve(message.encoded_len());
52 message
53 .encode(&mut self.encoding_buffer)
54 .map_err(io::Error::from)?;
55 let buffer =
56 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
57 .unwrap();
58
59 self.encoding_buffer.clear();
60 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
61 self.stream
62 .send(WebSocketMessage::Binary(buffer.into()))
63 .await?;
64 }
65 Message::Ping => {
66 self.stream
67 .send(WebSocketMessage::Ping(Default::default()))
68 .await?;
69 }
70 Message::Pong => {
71 self.stream
72 .send(WebSocketMessage::Pong(Default::default()))
73 .await?;
74 }
75 }
76
77 Ok(())
78 }
79}
80
81impl<S> MessageStream<S>
82where
83 S: futures::Stream<Item = anyhow::Result<WebSocketMessage>> + Unpin,
84{
85 pub async fn read(&mut self) -> anyhow::Result<(Message, Instant)> {
86 while let Some(bytes) = self.stream.next().await {
87 let received_at = Instant::now();
88 match bytes? {
89 WebSocketMessage::Binary(bytes) => {
90 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
91 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
92 .map_err(io::Error::from)?;
93
94 self.encoding_buffer.clear();
95 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
96 return Ok((Message::Envelope(envelope), received_at));
97 }
98 WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
99 WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
100 WebSocketMessage::Close(_) => break,
101 _ => {}
102 }
103 }
104 anyhow::bail!("connection closed");
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose only set field is `root_name`.
    fn worktree_envelope(root_name: String) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow::anyhow!("")));

        // A small message, then one whose encoding far exceeds MAX_BUFFER_LEN.
        for repetitions in [10, 1_000_000] {
            sink.write(worktree_envelope("abcdefg".repeat(repetitions)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}