1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
2use prost::Message;
3use std::{
4 convert::TryInto,
5 io,
6 time::{Duration, SystemTime, UNIX_EPOCH},
7};
8
// Pull in the protobuf types for the `zed.messages` package from `OUT_DIR`, presumably generated
// by this crate's build script. This provides `Envelope`, the `envelope::Payload` enum, and the
// individual message structs used below.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
10
11pub trait EnvelopedMessage: Clone + Sized + Send + 'static {
12 const NAME: &'static str;
13 fn into_envelope(
14 self,
15 id: u32,
16 responding_to: Option<u32>,
17 original_sender_id: Option<u32>,
18 ) -> Envelope;
19 fn matches_envelope(envelope: &Envelope) -> bool;
20 fn from_envelope(envelope: Envelope) -> Option<Self>;
21}
22
23pub trait RequestMessage: EnvelopedMessage {
24 type Response: EnvelopedMessage;
25}
26
/// Implements [`EnvelopedMessage`] for `$name`.
///
/// `$name` must be both a message struct and a variant of `envelope::Payload`.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            const NAME: &'static str = std::stringify!($name);

            fn into_envelope(
                self,
                id: u32,
                responding_to: Option<u32>,
                original_sender_id: Option<u32>,
            ) -> Envelope {
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    payload,
                    original_sender_id,
                    responding_to,
                    id,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                matches!(
                    envelope.payload.as_ref(),
                    Some(envelope::Payload::$name(_))
                )
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
60
/// Declares a request/response pair.
///
/// Implements [`EnvelopedMessage`] for both types via `message!`, and
/// [`RequestMessage`] for the request with `$response` as its `Response`.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }
        message!($request);
        message!($response);
    };
}
70
71request_message!(Auth, AuthResponse);
72request_message!(ShareWorktree, ShareWorktreeResponse);
73request_message!(OpenWorktree, OpenWorktreeResponse);
74message!(UpdateWorktree);
75message!(CloseWorktree);
76request_message!(OpenBuffer, OpenBufferResponse);
77message!(CloseBuffer);
78message!(UpdateBuffer);
79request_message!(SaveBuffer, BufferSaved);
80message!(AddPeer);
81message!(RemovePeer);
82
/// A stream of protobuf [`Envelope`] messages framed over an async byte stream.
///
/// Each frame is a 4-byte big-endian length prefix followed by the encoded message.
pub struct MessageStream<T> {
    /// The underlying transport.
    byte_stream: T,
    /// Scratch space reused when encoding outgoing frames and reading incoming ones.
    buffer: Vec<u8>,
    /// Body length announced by the last prefix that was read, if its body is still pending.
    upcoming_message_len: Option<usize>,
}
89
90impl<T> MessageStream<T> {
91 pub fn new(byte_stream: T) -> Self {
92 Self {
93 byte_stream,
94 buffer: Default::default(),
95 upcoming_message_len: None,
96 }
97 }
98
99 pub fn inner_mut(&mut self) -> &mut T {
100 &mut self.byte_stream
101 }
102}
103
104impl<T> MessageStream<T>
105where
106 T: AsyncWrite + Unpin,
107{
108 /// Write a given protobuf message to the stream.
109 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
110 let message_len: u32 = message
111 .encoded_len()
112 .try_into()
113 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
114 self.buffer.clear();
115 self.buffer.extend_from_slice(&message_len.to_be_bytes());
116 message.encode(&mut self.buffer)?;
117 self.byte_stream.write_all(&self.buffer).await
118 }
119}
120
121impl<T> MessageStream<T>
122where
123 T: AsyncRead + Unpin,
124{
125 /// Read a protobuf message of the given type from the stream.
126 pub async fn read_message(&mut self) -> io::Result<Envelope> {
127 loop {
128 if let Some(upcoming_message_len) = self.upcoming_message_len {
129 self.buffer.resize(upcoming_message_len, 0);
130 self.byte_stream.read_exact(&mut self.buffer).await?;
131 self.upcoming_message_len = None;
132 return Ok(Envelope::decode(self.buffer.as_slice())?);
133 } else {
134 self.buffer.resize(4, 0);
135 self.byte_stream.read_exact(&mut self.buffer).await?;
136 self.upcoming_message_len = Some(u32::from_be_bytes([
137 self.buffer[0],
138 self.buffer[1],
139 self.buffer[2],
140 self.buffer[3],
141 ]) as usize);
142 }
143 }
144 }
145}
146
147impl Into<SystemTime> for Timestamp {
148 fn into(self) -> SystemTime {
149 UNIX_EPOCH
150 .checked_add(Duration::new(self.seconds, self.nanos))
151 .unwrap()
152 }
153}
154
155impl From<SystemTime> for Timestamp {
156 fn from(time: SystemTime) -> Self {
157 let duration = time.duration_since(UNIX_EPOCH).unwrap();
158 Self {
159 seconds: duration.as_secs(),
160 nanos: duration.subsec_nanos(),
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream::new(3);

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);

            let message2 = OpenBuffer {
                worktree_id: 0,
                path: "some/path".to_string(),
            }
            .into_envelope(5, None, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            let mut message_stream = MessageStream::new(ChunkedStream::new(3));

            // Announce a 10-byte frame but supply only 3 bytes of its body.
            let stream = message_stream.inner_mut();
            stream.bytes.extend_from_slice(&10u32.to_be_bytes());
            stream.bytes.extend_from_slice(&[1, 2, 3]);

            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory byte stream that moves at most `chunk_size` bytes per poll.
    ///
    /// This exercises the framing code's handling of partial reads and writes. Written
    /// bytes are appended to `bytes`, and reads consume them from `read_offset` onward.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        /// Creates an empty stream that transfers at most `chunk_size` bytes per poll.
        fn new(chunk_size: usize) -> Self {
            Self {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returning 0 once everything written has been consumed signals end-of-stream.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}