1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::io;
5
6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
8use from_client as client;
9use from_server as server;
10
/// Associates a request message type with the message type that answers it.
///
/// Implementations are generated with the `request_response!` macro, which pairs a
/// `from_client` message with a `from_server` message.
pub trait Request {
    /// The message type paired with this request as its response.
    type Response;
}
14
/// Implements `Request` for the message type `$request`, declaring the message type
/// `$response` as its associated `Response`.
///
/// Usage: `request_response!(RequestPath, ResponsePath);`
macro_rules! request_response {
    ($request:path, $response:path) => {
        impl Request for $request {
            type Response = $response;
        }
    };
}
22
// Each invocation implements `Request` for a client message, naming the server message that
// answers it as the associated `Response` type.
request_response!(client::Auth, server::AuthResponse);
request_response!(client::NewWorktree, server::NewWorktreeResponse);
request_response!(client::ShareWorktree, server::ShareWorktreeResponse);
26
/// A stream of protobuf messages.
///
/// Wraps an underlying byte stream; messages are framed with a varint length delimiter
/// followed by the encoded message bytes.
pub struct MessageStream<T> {
    /// The underlying transport that framed messages are written to and read from.
    byte_stream: T,
    /// Scratch space that is cleared and reused for each message written or read.
    buffer: Vec<u8>,
}

impl<T> MessageStream<T> {
    /// Wraps `byte_stream`, starting with an empty scratch buffer.
    pub fn new(byte_stream: T) -> Self {
        let buffer = Vec::new();
        MessageStream {
            byte_stream,
            buffer,
        }
    }
}
41
42impl<T> MessageStream<T>
43where
44 T: AsyncWrite + Unpin,
45{
46 /// Write a given protobuf message to the stream.
47 pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
48 self.buffer.clear();
49 message.encode_length_delimited(&mut self.buffer).unwrap();
50 self.byte_stream.write_all(&self.buffer).await
51 }
52}
53
54impl<T> MessageStream<T>
55where
56 T: AsyncRead + Unpin,
57{
58 /// Read a protobuf message of the given type from the stream.
59 pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
60 // Ensure the buffer is large enough to hold the maximum delimiter length
61 const MAX_DELIMITER_LEN: usize = 10;
62 self.buffer.clear();
63 self.buffer.resize(MAX_DELIMITER_LEN, 0);
64
65 // Read until a complete length delimiter can be decoded.
66 let mut read_start_offset = 0;
67 let (encoded_len, delimiter_len) = loop {
68 let bytes_read = self
69 .byte_stream
70 .read(&mut self.buffer[read_start_offset..])
71 .await?;
72 read_start_offset += bytes_read;
73
74 let mut buffer = &self.buffer[0..read_start_offset];
75 match prost::decode_length_delimiter(&mut buffer) {
76 Err(_) => {
77 if read_start_offset >= MAX_DELIMITER_LEN {
78 return Err(io::Error::new(
79 io::ErrorKind::InvalidData,
80 "invalid message length delimiter",
81 ));
82 }
83 }
84 Ok(encoded_len) => {
85 let delimiter_len = read_start_offset - buffer.len();
86 break (encoded_len, delimiter_len);
87 }
88 }
89 };
90
91 // Read the message itself.
92 self.buffer.resize(delimiter_len + encoded_len, 0);
93 self.byte_stream
94 .read_exact(&mut self.buffer[read_start_offset..])
95 .await?;
96 let message = M::decode(&self.buffer[delimiter_len..])?;
97
98 Ok(message)
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Two messages written back to back are read back unchanged, even though the transport
    /// moves at most three bytes per call.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            // In reality there will never be both `FromClient` and `FromServer` messages
            // sent in the same direction on the same stream.
            let auth_message = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };
            let ack_message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some(
                        format!(
                            "a {}long error message that requires a two-byte length delimiter",
                            "very ".repeat(60)
                        )
                        .into(),
                    ),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::new(3));
            stream.write_message(&auth_message).await.unwrap();
            stream.write_message(&ack_message).await.unwrap();

            let decoded_auth = stream.read_message::<FromClient>().await.unwrap();
            let decoded_ack = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded_auth, auth_message);
            assert_eq!(decoded_ack, ack_message);
        });
    }

    /// A length delimiter longer than one transport chunk is still decoded correctly.
    #[test]
    fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
        smol::block_on(async {
            let transport = ChunkedStream::new(2);

            // This message is so long that its length delimiter requires three bytes,
            // so it won't be delivered in a single read from the chunked byte stream.
            let message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some("long ".repeat(256 * 256).into()),
                })),
            };
            let delimiter_len = prost::length_delimiter_len(message.encoded_len());
            assert!(delimiter_len > transport.chunk_size);

            let mut stream = MessageStream::new(transport);
            stream.write_message(&message).await.unwrap();
            let decoded = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded, message);
        });
    }

    /// Decoding a frame as the wrong message type yields an `InvalidData` error.
    #[test]
    fn test_protobuf_parse_error() {
        smol::block_on(async {
            let message = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::new(2));
            stream.write_message(&message).await.unwrap();

            // Read the wrong type of message from the stream.
            let error = stream.read_message::<FromServer>().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::InvalidData);
        });
    }

    /// In-memory transport that accepts or yields at most `chunk_size` bytes per call, so
    /// partial reads and writes get exercised. Writes append to `bytes`; reads consume `bytes`
    /// in order starting at `read_offset`.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        /// Creates an empty transport that moves at most `chunk_size` bytes per call.
        fn new(chunk_size: usize) -> Self {
            Self {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let limit = buf.len().min(self.chunk_size);
            let accepted = &buf[..limit];
            self.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let start = self.read_offset;
            let available = self.bytes.len() - start;
            let count = buf.len().min(self.chunk_size).min(available);
            let end = start + count;
            buf[..count].copy_from_slice(&self.bytes[start..end]);
            self.read_offset = end;
            Poll::Ready(Ok(count))
        }
    }
}