1#![allow(non_snake_case)]
2
3use anyhow::anyhow;
4use async_tungstenite::tungstenite::Message as WebSocketMessage;
5use futures::{SinkExt as _, StreamExt as _};
6pub use proto::{Message as _, *};
7use std::time::Instant;
8use std::{fmt::Debug, io};
9
/// Bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Largest capacity the reusable encoding buffer keeps between messages. After each message
/// is encoded or decoded, the buffer is shrunk back to this size, so a single oversized
/// message does not pin its allocation for the lifetime of the stream.
const MAX_BUFFER_LEN: usize = MIB;
13
/// A stream of protobuf messages.
///
/// Wraps an underlying transport `S` and owns a scratch buffer that is reused across
/// calls for protobuf encoding and decoding.
pub struct MessageStream<S> {
    /// Scratch space for (de)serializing envelopes; emptied after every message.
    encoding_buffer: Vec<u8>,
    /// The wrapped transport (a sink of frames, a stream of frames, or both).
    stream: S,
}
19
20#[allow(clippy::large_enum_variant)]
21#[derive(Debug)]
22pub enum Message {
23 Envelope(Envelope),
24 Ping,
25 Pong,
26}
27
28impl<S> MessageStream<S> {
29 pub fn new(stream: S) -> Self {
30 Self {
31 stream,
32 encoding_buffer: Vec::new(),
33 }
34 }
35
36 pub fn inner_mut(&mut self) -> &mut S {
37 &mut self.stream
38 }
39}
40
41impl<S> MessageStream<S>
42where
43 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
44{
45 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
46 #[cfg(any(test, feature = "test-support"))]
47 const COMPRESSION_LEVEL: i32 = -7;
48
49 #[cfg(not(any(test, feature = "test-support")))]
50 const COMPRESSION_LEVEL: i32 = 4;
51
52 match message {
53 Message::Envelope(message) => {
54 self.encoding_buffer.reserve(message.encoded_len());
55 message
56 .encode(&mut self.encoding_buffer)
57 .map_err(io::Error::from)?;
58 let buffer =
59 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
60 .unwrap();
61
62 self.encoding_buffer.clear();
63 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
64 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
65 }
66 Message::Ping => {
67 self.stream
68 .send(WebSocketMessage::Ping(Default::default()))
69 .await?;
70 }
71 Message::Pong => {
72 self.stream
73 .send(WebSocketMessage::Pong(Default::default()))
74 .await?;
75 }
76 }
77
78 Ok(())
79 }
80}
81
82impl<S> MessageStream<S>
83where
84 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
85{
86 pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
87 while let Some(bytes) = self.stream.next().await {
88 let received_at = Instant::now();
89 match bytes? {
90 WebSocketMessage::Binary(bytes) => {
91 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
92 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
93 .map_err(io::Error::from)?;
94
95 self.encoding_buffer.clear();
96 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
97 return Ok((Message::Envelope(envelope), received_at));
98 }
99 WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
100 WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
101 WebSocketMessage::Close(_) => break,
102 _ => {}
103 }
104 }
105 Err(anyhow!("connection closed"))
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `UpdateWorktree` envelope whose root name is `chunk` repeated `count`
    /// times, so the encoded payload size can be scaled up or down.
    fn worktree_update(chunk: &str, count: usize) -> Message {
        Message::Envelope(Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: chunk.repeat(count),
                ..Default::default()
            })),
            ..Default::default()
        })
    }

    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        // Write one small message, then one whose encoding far exceeds MAX_BUFFER_LEN.
        // The scratch buffer must be shrunk back after each write.
        let mut writer = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for count in [10, 1_000_000] {
            writer
                .write(worktree_update("abcdefg", count))
                .await
                .unwrap();
            assert!(writer.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        // Read both messages back; the same capacity bound must hold after each decode.
        let mut reader = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            reader.read().await.unwrap();
            assert!(reader.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}