1#![allow(non_snake_case)]
2
3pub use ::proto::*;
4
5use anyhow::anyhow;
6use async_tungstenite::tungstenite::Message as WebSocketMessage;
7use futures::{SinkExt as _, StreamExt as _};
8use proto::Message as _;
9use std::time::Instant;
10use std::{fmt::Debug, io};
11
/// One kibibyte, in bytes.
const KIB: usize = 1 << 10;
/// One mebibyte, in bytes.
const MIB: usize = KIB << 10;
/// Capacity that a stream's scratch encoding buffer is trimmed back to
/// (via `shrink_to`) after each message is written or read.
const MAX_BUFFER_LEN: usize = MIB;
15
/// A stream of protobuf messages.
///
/// Wraps a WebSocket transport `S` and converts between [`Message`] values and
/// WebSocket frames: envelopes travel as zstd-compressed protobuf in binary
/// frames, while pings and pongs are sent as WebSocket control frames.
pub struct MessageStream<S> {
    /// The underlying WebSocket sink (for `write`) and/or stream (for `read`).
    stream: S,
    /// Scratch buffer reused for protobuf encoding on write and zstd
    /// decompression on read; it is cleared and its capacity trimmed after
    /// each message.
    encoding_buffer: Vec<u8>,
}
21
/// A protocol-level message carried by a [`MessageStream`].
// NOTE(review): the `large_enum_variant` lint is allowed rather than boxing
// `Envelope`; presumably to avoid an extra heap allocation per message — confirm.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)]
pub enum Message {
    /// A protobuf envelope, transmitted as a zstd-compressed binary frame.
    Envelope(Envelope),
    /// A WebSocket ping control frame (empty payload when written).
    Ping,
    /// A WebSocket pong control frame (empty payload when written).
    Pong,
}
29
30impl<S> MessageStream<S> {
31 pub fn new(stream: S) -> Self {
32 Self {
33 stream,
34 encoding_buffer: Vec::new(),
35 }
36 }
37}
38
39impl<S> MessageStream<S>
40where
41 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
42{
43 pub async fn write(&mut self, message: Message) -> Result<(), anyhow::Error> {
44 #[cfg(any(test, feature = "test-support"))]
45 const COMPRESSION_LEVEL: i32 = -7;
46
47 #[cfg(not(any(test, feature = "test-support")))]
48 const COMPRESSION_LEVEL: i32 = 4;
49
50 match message {
51 Message::Envelope(message) => {
52 self.encoding_buffer.reserve(message.encoded_len());
53 message
54 .encode(&mut self.encoding_buffer)
55 .map_err(io::Error::from)?;
56 let buffer =
57 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
58 .unwrap();
59
60 self.encoding_buffer.clear();
61 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
62 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
63 }
64 Message::Ping => {
65 self.stream
66 .send(WebSocketMessage::Ping(Default::default()))
67 .await?;
68 }
69 Message::Pong => {
70 self.stream
71 .send(WebSocketMessage::Pong(Default::default()))
72 .await?;
73 }
74 }
75
76 Ok(())
77 }
78}
79
80impl<S> MessageStream<S>
81where
82 S: futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>> + Unpin,
83{
84 pub async fn read(&mut self) -> Result<(Message, Instant), anyhow::Error> {
85 while let Some(bytes) = self.stream.next().await {
86 let received_at = Instant::now();
87 match bytes? {
88 WebSocketMessage::Binary(bytes) => {
89 zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer)?;
90 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
91 .map_err(io::Error::from)?;
92
93 self.encoding_buffer.clear();
94 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
95 return Ok((Message::Envelope(envelope), received_at));
96 }
97 WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
98 WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
99 WebSocketMessage::Close(_) => break,
100 _ => {}
101 }
102 }
103 Err(anyhow!("connection closed"))
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope whose only payload is an `UpdateWorktree` whose root
    /// name is `"abcdefg"` repeated `repeats` times.
    fn worktree_envelope(repeats: usize) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name: "abcdefg".repeat(repeats),
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// After every write and read — including one for a message far larger
    /// than `MAX_BUFFER_LEN` — the scratch buffer's capacity must be trimmed
    /// back to at most `MAX_BUFFER_LEN`.
    #[gpui::test]
    async fn test_buffer_size() {
        let (tx, rx) = futures::channel::mpsc::unbounded();

        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow!("")));
        for repeats in [10, 1000000] {
            sink.write(Message::Envelope(worktree_envelope(repeats)))
                .await
                .unwrap();
            assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        for _ in 0..2 {
            stream.read().await.unwrap();
            assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        }
    }
}