1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::{convert::TryInto, io};
5
// Pull in the message types generated at build time into OUT_DIR (presumably by
// prost-build from a `zed.messages` protobuf schema — TODO confirm against build.rs).
// This is where `FromClient`, the `from_client` / `from_server` modules, their
// `Variant` enums and the per-message structs used below come from.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
8/// A message that the client can send to the server.
9pub trait ClientMessage: Sized {
10 fn to_variant(self) -> from_client::Variant;
11 fn from_variant(variant: from_client::Variant) -> Option<Self>;
12}
13
14/// A message that the server can send to the client.
15pub trait ServerMessage: Sized {
16 fn to_variant(self) -> from_server::Variant;
17 fn from_variant(variant: from_server::Variant) -> Option<Self>;
18}
19
20/// A message that the client can send to the server, where the server must respond with a single
21/// message of a certain type.
22pub trait RequestMessage: ClientMessage {
23 type Response: ServerMessage;
24}
25
26/// A message that the client can send to the server, where the server must respond with a series of
27/// messages of a certain type.
28pub trait SubscribeMessage: ClientMessage {
29 type Event: ServerMessage;
30}
31
32/// A message that the client can send to the server, where the server will not respond.
33pub trait SendMessage: ClientMessage {}
34
/// Implements a direction trait (e.g. `ClientMessage` or `ServerMessage`) for
/// `$module::$msg_type`, mapping it to and from the `$module::Variant::$msg_type`
/// enum case.
///
/// Arguments, in order: the message type name, the trait to implement, and the
/// module that holds both the message type and its `Variant` enum.
macro_rules! directed_message {
    ($msg_type:ident, $kind_trait:ident, $module:ident) => {
        impl $kind_trait for $module::$msg_type {
            fn to_variant(self) -> $module::Variant {
                $module::Variant::$msg_type(self)
            }

            fn from_variant(variant: $module::Variant) -> Option<Self> {
                match variant {
                    $module::Variant::$msg_type(message) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
52
/// Declares `from_client::$request` as a request answered by a single
/// `from_server::$response`, implementing `RequestMessage` for the request plus
/// `ClientMessage` / `ServerMessage` for the two halves of the pair.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for from_client::$request {
            type Response = from_server::$response;
        }
        directed_message!($request, ClientMessage, from_client);
        directed_message!($response, ServerMessage, from_server);
    };
}
62
/// Declares `from_client::$message` as a one-way client message: implements
/// `ClientMessage` for it along with the `SendMessage` marker trait.
macro_rules! send_message {
    ($message:ident) => {
        impl SendMessage for from_client::$message {}
        directed_message!($message, ClientMessage, from_client);
    };
}
69
/// Declares `from_client::$request` as a subscription whose events arrive as
/// `from_server::$event_type`, implementing `SubscribeMessage` for the request
/// plus `ClientMessage` / `ServerMessage` for the two message types.
macro_rules! subscribe_message {
    ($request:ident, $event_type:ident) => {
        impl SubscribeMessage for from_client::$request {
            type Event = from_server::$event_type;
        }
        directed_message!($request, ClientMessage, from_client);
        directed_message!($event_type, ServerMessage, from_server);
    };
}
79
80request_message!(Auth, AuthResponse);
81request_message!(NewWorktree, NewWorktreeResponse);
82request_message!(ShareWorktree, ShareWorktreeResponse);
83send_message!(UploadFile);
84subscribe_message!(SubscribeToPathRequests, PathRequest);
85
/// Reads and writes length-delimited protobuf messages over a byte stream.
///
/// Each frame is a 4-byte big-endian length prefix followed by the encoded
/// message. `buffer` is scratch space that is reused across calls.
pub struct MessageStream<T> {
    byte_stream: T,
    buffer: Vec<u8>,
}

impl<T> MessageStream<T> {
    /// Wraps `byte_stream`, starting with an empty scratch buffer.
    pub fn new(byte_stream: T) -> Self {
        let buffer = Vec::new();
        MessageStream {
            byte_stream,
            buffer,
        }
    }

    /// Gives mutable access to the wrapped byte stream.
    pub fn inner_mut(&mut self) -> &mut T {
        let Self { byte_stream, .. } = self;
        byte_stream
    }
}
104
105impl<T> MessageStream<T>
106where
107 T: AsyncWrite + Unpin,
108{
109 /// Write a given protobuf message to the stream.
110 pub async fn write_message(&mut self, message: &impl Message) -> io::Result<()> {
111 let message_len: u32 = message
112 .encoded_len()
113 .try_into()
114 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
115 self.buffer.clear();
116 self.buffer.extend_from_slice(&message_len.to_be_bytes());
117 message.encode(&mut self.buffer)?;
118 self.byte_stream.write_all(&self.buffer).await
119 }
120}
121
122impl<T> MessageStream<T>
123where
124 T: AsyncRead + Unpin,
125{
126 /// Read a protobuf message of the given type from the stream.
127 pub async fn read_message<M: Message + Default>(&mut self) -> futures_io::Result<M> {
128 let mut delimiter_buf = [0; 4];
129 self.byte_stream.read_exact(&mut delimiter_buf).await?;
130 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
131 self.buffer.resize(message_len, 0);
132 self.byte_stream.read_exact(&mut self.buffer).await?;
133 Ok(M::decode(self.buffer.as_slice())?)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = FromClient {
                id: 3,
                variant: Some(from_client::Variant::Auth(from_client::Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                })),
            };
            // The payload is several hundred bytes long, so the 4-byte length prefix
            // has more than one significant byte and the body spans many 3-byte chunks.
            let message2 = FromClient {
                id: 4,
                variant: Some(from_client::Variant::UploadFile(from_client::UploadFile {
                    path: Vec::new(),
                    content: format!(
                        "a {}long error message that requires a two-byte length delimiter",
                        "very ".repeat(60)
                    )
                    .into(),
                })),
            };

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message::<FromClient>().await.unwrap();
            let decoded_message2 = message_stream.read_message::<FromClient>().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    #[test]
    fn test_round_trip_default_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            // A message with no variant set exercises the smallest-body path.
            let message = FromClient {
                id: 0,
                variant: None,
            };

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message).await.unwrap();
            let decoded = message_stream.read_message::<FromClient>().await.unwrap();
            assert_eq!(decoded, message);
        });
    }

    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            // The prefix claims a 100-byte body, but only 3 bytes follow it.
            let mut bytes = 100u32.to_be_bytes().to_vec();
            bytes.extend_from_slice(&[1, 2, 3]);
            let byte_stream = ChunkedStream {
                bytes,
                read_offset: 0,
                chunk_size: 3,
            };

            let mut message_stream = MessageStream::new(byte_stream);
            let error = message_stream
                .read_message::<FromClient>()
                .await
                .unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory stream for the tests. Writes append to `bytes` and reads consume
    /// from `read_offset`. Each poll moves at most `chunk_size` bytes, so callers
    /// must cope with partial reads and writes.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) once every written byte has been consumed (end of stream).
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}