1#![allow(non_snake_case)]
2
3pub use ::proto::*;
4
5use async_tungstenite::tungstenite::Message as WebSocketMessage;
6use futures::{SinkExt as _, StreamExt as _};
7use proto::Message as _;
8use std::time::Instant;
9use std::{fmt::Debug, io};
10
/// One kibibyte (2^10 bytes).
const KIB: usize = 1 << 10;
/// One mebibyte (2^20 bytes).
const MIB: usize = KIB << 10;
/// Largest capacity a stream's scratch buffer keeps between messages; anything larger is
/// released after each message is processed.
const MAX_BUFFER_LEN: usize = MIB;
14
/// A stream of protobuf messages, carried as zstd-compressed binary WebSocket frames.
///
/// `S` is the underlying transport. Reading requires it to be a stream of WebSocket messages,
/// and writing requires it to be a sink of them.
pub struct MessageStream<S> {
    /// The wrapped WebSocket transport.
    stream: S,
    /// Scratch space reused across messages for encoding (on write) and decoding (on read).
    /// Its capacity is trimmed back to `MAX_BUFFER_LEN` after each message.
    encoding_buffer: Vec<u8>,
}
20
21#[derive(Debug)]
22pub enum Message {
23 Envelope(Envelope),
24 Ping,
25 Pong,
26}
27
28impl<S> MessageStream<S> {
29 pub fn new(stream: S) -> Self {
30 Self {
31 stream,
32 encoding_buffer: Vec::new(),
33 }
34 }
35}
36
37impl<S> MessageStream<S>
38where
39 S: futures::Sink<WebSocketMessage, Error = anyhow::Error> + Unpin,
40{
41 pub async fn write(&mut self, message: Message) -> anyhow::Result<()> {
42 #[cfg(any(test, feature = "test-support"))]
43 const COMPRESSION_LEVEL: i32 = -7;
44
45 #[cfg(not(any(test, feature = "test-support")))]
46 const COMPRESSION_LEVEL: i32 = 4;
47
48 match message {
49 Message::Envelope(message) => {
50 self.encoding_buffer.reserve(message.encoded_len());
51 message
52 .encode(&mut self.encoding_buffer)
53 .map_err(io::Error::from)?;
54 let buffer =
55 zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
56 .unwrap();
57
58 self.encoding_buffer.clear();
59 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
60 self.stream
61 .send(WebSocketMessage::Binary(buffer.into()))
62 .await?;
63 }
64 Message::Ping => {
65 self.stream
66 .send(WebSocketMessage::Ping(Default::default()))
67 .await?;
68 }
69 Message::Pong => {
70 self.stream
71 .send(WebSocketMessage::Pong(Default::default()))
72 .await?;
73 }
74 }
75
76 Ok(())
77 }
78}
79
80impl<S> MessageStream<S>
81where
82 S: futures::Stream<Item = anyhow::Result<WebSocketMessage>> + Unpin,
83{
84 pub async fn read(&mut self) -> anyhow::Result<(Message, Instant)> {
85 while let Some(bytes) = self.stream.next().await {
86 let received_at = Instant::now();
87 match bytes? {
88 WebSocketMessage::Binary(bytes) => {
89 zstd::stream::copy_decode(
90 zstd::zstd_safe::WriteBuf::as_slice(&*bytes),
91 &mut self.encoding_buffer,
92 )?;
93 let envelope = Envelope::decode(self.encoding_buffer.as_slice())
94 .map_err(io::Error::from)?;
95
96 self.encoding_buffer.clear();
97 self.encoding_buffer.shrink_to(MAX_BUFFER_LEN);
98 return Ok((Message::Envelope(envelope), received_at));
99 }
100 WebSocketMessage::Ping(_) => return Ok((Message::Ping, received_at)),
101 WebSocketMessage::Pong(_) => return Ok((Message::Pong, received_at)),
102 WebSocketMessage::Close(_) => break,
103 _ => {}
104 }
105 }
106 anyhow::bail!("connection closed");
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an envelope carrying an `UpdateWorktree` payload with the given root name, so the
    /// encoded message size can be controlled through the name's length.
    fn update_worktree_envelope(root_name: String) -> Envelope {
        Envelope {
            payload: Some(envelope::Payload::UpdateWorktree(UpdateWorktree {
                root_name,
                ..Default::default()
            })),
            ..Default::default()
        }
    }

    /// Extracts the `UpdateWorktree` root name from a message read off the stream, panicking if
    /// the message has any other shape.
    fn root_name_of(message: Message) -> String {
        match message {
            Message::Envelope(Envelope {
                payload: Some(envelope::Payload::UpdateWorktree(update)),
                ..
            }) => update.root_name,
            other => panic!("expected an UpdateWorktree envelope, got {other:?}"),
        }
    }

    /// Checks that scratch buffers never retain more than `MAX_BUFFER_LEN` bytes of capacity,
    /// on either the write or the read side, even after a message far larger than that limit.
    /// Also checks that both messages survive the round trip intact.
    #[gpui::test]
    async fn test_buffer_size() {
        let small_name = "abcdefg".repeat(10);
        let large_name = "abcdefg".repeat(1000000);

        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut sink = MessageStream::new(tx.sink_map_err(|_| anyhow::anyhow!("")));
        sink.write(Message::Envelope(update_worktree_envelope(
            small_name.clone(),
        )))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        sink.write(Message::Envelope(update_worktree_envelope(
            large_name.clone(),
        )))
        .await
        .unwrap();
        assert!(sink.encoding_buffer.capacity() <= MAX_BUFFER_LEN);

        let mut stream = MessageStream::new(rx.map(anyhow::Ok));
        let (message, _) = stream.read().await.unwrap();
        assert_eq!(root_name_of(message), small_name);
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
        let (message, _) = stream.read().await.unwrap();
        // Plain `assert!` so a failure does not dump a multi-megabyte string.
        assert!(
            root_name_of(message) == large_name,
            "large message did not round-trip"
        );
        assert!(stream.encoding_buffer.capacity() <= MAX_BUFFER_LEN);
    }
}