1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::io;
5
6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
/// A message type that is paired with the type expected in reply to it.
///
/// Implementors name that reply type through [`Request::Response`].
pub trait Request {
    /// The message type expected in reply to this request.
    type Response;
}
11
12impl Request for from_client::Auth {
13 type Response = from_server::Ack;
14}
15
/// Frames protobuf messages on top of an underlying byte stream.
///
/// Each message travels as a varint length delimiter followed by the encoded message bytes
/// (the format produced by `prost::Message::encode_length_delimited`). The `buffer` field is
/// scratch space that is cleared and reused by each read and write, rather than allocated
/// fresh per message.
pub struct MessageStream<T> {
    byte_stream: T,
    buffer: Vec<u8>,
}

impl<T> MessageStream<T> {
    /// Wraps `byte_stream`, starting with an empty scratch buffer.
    pub fn new(byte_stream: T) -> Self {
        MessageStream {
            buffer: Vec::new(),
            byte_stream,
        }
    }
}
30
31impl<T> MessageStream<T>
32where
33 T: AsyncWrite + Unpin,
34{
35 /// Write a given protobuf message to the stream.
36 pub async fn write_message(&mut self, message: &impl Message) -> futures_io::Result<()> {
37 self.buffer.clear();
38 message.encode_length_delimited(&mut self.buffer).unwrap();
39 self.byte_stream.write_all(&self.buffer).await
40 }
41}
42
43impl<T> MessageStream<T>
44where
45 T: AsyncRead + Unpin,
46{
47 /// Read a protobuf message of the given type from the stream.
48 pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
49 // Ensure the buffer is large enough to hold the maximum delimiter length
50 const MAX_DELIMITER_LEN: usize = 10;
51 self.buffer.clear();
52 self.buffer.resize(MAX_DELIMITER_LEN, 0);
53
54 // Read until a complete length delimiter can be decoded.
55 let mut read_start_offset = 0;
56 let (encoded_len, delimiter_len) = loop {
57 let bytes_read = self
58 .byte_stream
59 .read(&mut self.buffer[read_start_offset..])
60 .await?;
61 read_start_offset += bytes_read;
62
63 let mut buffer = &self.buffer[0..read_start_offset];
64 match prost::decode_length_delimiter(&mut buffer) {
65 Err(_) => {
66 if read_start_offset >= MAX_DELIMITER_LEN {
67 return Err(io::Error::new(
68 io::ErrorKind::InvalidData,
69 "invalid message length delimiter",
70 ));
71 }
72 }
73 Ok(encoded_len) => {
74 let delimiter_len = read_start_offset - buffer.len();
75 break (encoded_len, delimiter_len);
76 }
77 }
78 };
79
80 // Read the message itself.
81 self.buffer.resize(delimiter_len + encoded_len, 0);
82 self.byte_stream
83 .read_exact(&mut self.buffer[read_start_offset..])
84 .await?;
85 let message = M::decode(&self.buffer[delimiter_len..])?;
86
87 Ok(message)
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            // In reality there will never be both `FromClient` and `FromServer` messages
            // sent in the same direction on the same stream.
            let client_message = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };
            let long_error = format!(
                "a {}long error message that requires a two-byte length delimiter",
                "very ".repeat(60)
            );
            let server_message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some(long_error.into()),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::new(3));
            stream.write_message(&client_message).await.unwrap();
            stream.write_message(&server_message).await.unwrap();

            let decoded_client = stream.read_message::<FromClient>().await.unwrap();
            let decoded_server = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded_client, client_message);
            assert_eq!(decoded_server, server_message);
        });
    }

    #[test]
    fn test_read_message_when_length_delimiter_is_not_complete_in_first_read() {
        smol::block_on(async {
            let chunked = ChunkedStream::new(2);

            // This message is so long that its length delimiter requires three bytes,
            // so it won't be delivered in a single read from the chunked byte stream.
            let message = FromServer {
                request_id: Some(4),
                variant: Some(from_server::Variant::Ack(from_server::Ack {
                    error_message: Some("long ".repeat(256 * 256).into()),
                })),
            };
            let delimiter_len = prost::length_delimiter_len(message.encoded_len());
            assert!(delimiter_len > chunked.chunk_size);

            let mut stream = MessageStream::new(chunked);
            stream.write_message(&message).await.unwrap();
            let decoded = stream.read_message::<FromServer>().await.unwrap();
            assert_eq!(decoded, message);
        });
    }

    #[test]
    fn test_protobuf_parse_error() {
        smol::block_on(async {
            let message = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };

            let mut stream = MessageStream::new(ChunkedStream::new(2));
            stream.write_message(&message).await.unwrap();

            // Read the wrong type of message from the stream.
            let error = stream.read_message::<FromServer>().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::InvalidData);
        });
    }

    /// An in-memory byte stream that moves at most `chunk_size` bytes per read or write,
    /// used to exercise partial reads and writes.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        /// Creates an empty stream that transfers at most `chunk_size` bytes per operation.
        fn new(chunk_size: usize) -> Self {
            ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = &mut *self;
            let accepted = &buf[..buf.len().min(this.chunk_size)];
            this.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = &mut *self;
            let unread = &this.bytes[this.read_offset..];
            let count = buf.len().min(this.chunk_size).min(unread.len());
            buf[..count].copy_from_slice(&unread[..count]);
            this.read_offset += count;
            Poll::Ready(Ok(count))
        }
    }
}