1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::{convert::TryInto, io};
5
// Pull in the generated protobuf message types (`FromClient`, `from_client`, `from_server`, ...)
// from the build output directory. NOTE(review): presumably produced by a build script at
// compile time — confirm against build.rs.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
8/// A message that the client can send to the server.
9pub trait ClientMessage: Sized {
10 fn to_variant(self) -> from_client::Variant;
11 fn from_variant(variant: from_client::Variant) -> Option<Self>;
12}
13
14/// A message that the server can send to the client.
15pub trait ServerMessage: Sized {
16 fn to_variant(self) -> from_server::Variant;
17 fn from_variant(variant: from_server::Variant) -> Option<Self>;
18}
19
20/// A message that the client can send to the server, where the server must respond with a single
21/// message of a certain type.
22pub trait RequestMessage: ClientMessage {
23 type Response: ServerMessage;
24}
25
26/// A message that the client can send to the server, where the server must respond with a series of
27/// messages of a certain type.
28pub trait SubscribeMessage: ClientMessage {
29 type Event: ServerMessage;
30}
31
32/// A message that the client can send to the server, where the server will not respond.
33pub trait SendMessage: ClientMessage {}
34
/// Implements the direction trait `$trait_name` (`ClientMessage` or `ServerMessage`) for
/// the message struct `$module::$msg`, mapping it to and from the enum variant
/// `$module::Variant::$msg`.
macro_rules! directed_message {
    ($msg:ident, $trait_name:ident, $module:ident) => {
        impl $trait_name for $module::$msg {
            fn from_variant(variant: $module::Variant) -> Option<Self> {
                match variant {
                    $module::Variant::$msg(inner) => Some(inner),
                    _ => None,
                }
            }

            fn to_variant(self) -> $module::Variant {
                $module::Variant::$msg(self)
            }
        }
    };
}
52
/// Declares `from_client::$request` as a request that the server answers with
/// `from_server::$response`, and hooks both types into their direction's `Variant` enum.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for from_client::$request {
            type Response = from_server::$response;
        }
        directed_message!($request, ClientMessage, from_client);
        directed_message!($response, ServerMessage, from_server);
    };
}
62
/// Declares `from_client::$message` as a one-way client message (no server reply) and
/// hooks it into the `from_client::Variant` enum.
macro_rules! send_message {
    ($message:ident) => {
        impl SendMessage for from_client::$message {}
        directed_message!($message, ClientMessage, from_client);
    };
}
69
70request_message!(Auth, AuthResponse);
71request_message!(NewWorktree, NewWorktreeResponse);
72request_message!(ShareWorktree, ShareWorktreeResponse);
73send_message!(UploadFile);
74
/// Sends and receives protobuf messages over an underlying byte stream.
///
/// Every message is framed as a 4-byte big-endian length followed by its encoded bytes.
pub struct MessageStream<T> {
    /// The transport that framed messages are written to and read from.
    byte_stream: T,
    /// Scratch space, reused across calls, holding the frame currently being built or read.
    buffer: Vec<u8>,
}
80
81impl<T> MessageStream<T> {
82 pub fn new(byte_stream: T) -> Self {
83 Self {
84 byte_stream,
85 buffer: Default::default(),
86 }
87 }
88
89 pub fn inner_mut(&mut self) -> &mut T {
90 &mut self.byte_stream
91 }
92}
93
94impl<T> MessageStream<T>
95where
96 T: AsyncWrite + Unpin,
97{
98 /// Write a given protobuf message to the stream.
99 pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
100 let message_len: u32 = message
101 .encoded_len()
102 .try_into()
103 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
104 self.buffer.clear();
105 self.buffer.extend_from_slice(&message_len.to_be_bytes());
106 message.encode(&mut self.buffer)?;
107 self.byte_stream.write_all(&self.buffer).await
108 }
109}
110
111impl<T> MessageStream<T>
112where
113 T: AsyncRead + Unpin,
114{
115 /// Read a protobuf message of the given type from the stream.
116 pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
117 let mut delimiter_buf = [0; 4];
118 self.byte_stream.read_exact(&mut delimiter_buf).await?;
119 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
120 self.buffer.resize(message_len, 0);
121 self.byte_stream.read_exact(&mut self.buffer).await?;
122 Ok(M::decode(self.buffer.as_slice())?)
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Writes two messages through a stream that moves at most `chunk_size` bytes per poll,
    /// then reads them back and checks that both round-trip unchanged.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };
            // Long enough (> 255 bytes encoded) that more than one byte of the 4-byte
            // length prefix is non-zero.
            let message2 = FromClient {
                id: 4,
                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
                    path: Vec::new(),
                    content: format!(
                        "a {}long message whose length does not fit in a single byte",
                        "very ".repeat(60)
                    )
                    .into(),
                })),
            };

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message::<FromClient>().await.unwrap();
            let decoded_message2 = message_stream.read_message::<FromClient>().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// Reading from a stream with no data must fail instead of producing a message.
    #[test]
    fn test_read_message_from_empty_stream() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream
                .read_message::<FromClient>()
                .await
                .unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// A frame whose payload is shorter than its length prefix announces must fail.
    #[test]
    fn test_read_message_truncated_frame() {
        smol::block_on(async {
            // The prefix announces 10 payload bytes, but only 2 follow.
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: vec![0, 0, 0, 10, 1, 2],
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream
                .read_message::<FromClient>()
                .await
                .unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// An in-memory stream that accepts or yields at most `chunk_size` bytes per poll.
    /// This forces `MessageStream` to cope with partial writes and reads.
    struct ChunkedStream {
        /// Everything written so far; reads consume from this same buffer.
        bytes: Vec<u8>,
        /// Index in `bytes` of the next byte to hand out on read.
        read_offset: usize,
        /// Maximum number of bytes transferred by a single poll.
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Yields Ok(0) once every written byte has been consumed, signalling EOF.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}