1#![allow(non_snake_case)]
2
3pub use ::proto::*;
4
5use anyhow::anyhow;
6use async_tungstenite::tungstenite::Message as WebSocketMessage;
7use futures::{SinkExt as _, StreamExt as _};
8use proto::Message as _;
9use std::time::Instant;
10use std::{fmt::Debug, io};
11use zstd::zstd_safe::WriteBuf;
12
/// Bytes in one kibibyte (2^10).
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte (2^20).
const MIB: usize = KIB << 10;
/// Capacity that a stream's scratch buffer is shrunk back to (via `Vec::shrink_to`)
/// after an envelope is processed, so one oversized message does not keep a large
/// allocation alive for the lifetime of the stream.
const MAX_BUFFER_LEN: usize = MIB;
16
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and converts between [`Message`] values and
/// WebSocket frames: envelopes are protobuf-encoded and zstd-compressed on write,
/// and zstd-decompressed and protobuf-decoded on read.
pub struct MessageStream<S> {
    /// Underlying transport. It is used as a `futures::Sink` by `write` and as a
    /// `futures::Stream` by `read`, depending on which bounds `S` satisfies.
    stream: S,
    /// Scratch buffer reused across messages to hold uncompressed protobuf bytes.
    /// After processing an envelope, `write`/`read` clear it and shrink its capacity
    /// back to `MAX_BUFFER_LEN`.
    encoding_buffer: Vec<u8>,
}
22
/// A unit of traffic carried by a [`MessageStream`].
///
/// `Envelope` is sent as a zstd-compressed binary WebSocket frame; `Ping` and `Pong`
/// map to WebSocket ping/pong control frames (written with empty payloads).
// Clippy's `large_enum_variant` lint is allowed so `Envelope` can be stored inline
// rather than boxed alongside the payload-free `Ping`/`Pong` variants.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope carrying an RPC payload.
    Envelope(Envelope),
    /// A WebSocket ping control frame.
    Ping,
    /// A WebSocket pong control frame.
    Pong,
}
30
31impl<S> MessageStream<S> {
32 pub fn new(stream: S) -> Self {
33 Self {
34 stream,
35 encoding_buffer: Vec::new(),
36 }
37 }
38}
39
40impl<S> MessageStream<S>
41where
42 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
43{
44 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
45 #[cfg(any(test, feature = "test-support"))]
46 const COMPRESSION_LEVEL: i32 = -7;
47
48 #[cfg(not(any(test, feature = "test-support")))]
49 const COMPRESSION_LEVEL: i32 = 4;
50
51 match message {
52 Message::Envelope(message) => {
53 self.encoding_buffer.reserve(message.encoded_len());
54 message
55 .encode(&mut self.encoding_buffer)
56 .map_err(io::Error::from)?;
57 let buffer =
58 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
59 .unwrap();
60
61 self.encoding_buffer.clear();
62 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
63 self.stream
64 .send(WebSocketMessage::Binary(buffer.into()))
65 .await?;
66 }
67 Message::Ping => {
68 self.stream
69 .send(WebSocketMessage::Ping(Default::default()))
70 .await?;
71 }
72 Message::Pong => {
73 self.stream
74 .send(WebSocketMessage::Pong(Default::default()))
75 .await?;
76 }
77 }
78
79 Ok(())
80 }
81}
82
83impl<S> MessageStream<S>
84where
85 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
86{
87 pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
88 while let Some(bytes) = self.stream.next().await {
89 let received_at = Instant::now();
90 match bytes? {
91 WebSocketMessage::Binary(bytes) => {
92 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
93 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
94 .map_err(io::Error::from)?;
95
96 self.encoding_buffer.clear();
97 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
98 return Ok((Message::Envelope(envelope), received_at));
99 }
100 WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
101 WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
102 WebSocketMessage::Close(_) => break,
103 _ => {}
104 }
105 }
106 Err(anyhow!("connection closed"))
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose only payload is an `UpdateWorktree` with `root_name`.
    fn update_worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// Writes a small envelope and a very large (~7 MB root name) envelope, then reads
    /// both back. After every operation, checks that the scratch buffer's capacity
    /// does not exceed `MAX_BUFFER_LEN`.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repetitions in [10, 1_000_000] {
            let envelope = update_worktree_envelope("abcdefg".repeat(repetitions));
            sink.write(Message::Envelope(envelope)).await.unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}