1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
2use prost::Message;
3use std::{convert::TryInto, io};
4
5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
6
7pub trait EnvelopedMessage: Sized + Send + 'static {
8 const NAME: &'static str;
9 fn into_envelope(
10 self,
11 id: u32,
12 responding_to: Option<u32>,
13 original_sender_id: Option<u32>,
14 ) -> Envelope;
15 fn matches_envelope(envelope: &Envelope) -> bool;
16 fn from_envelope(envelope: Envelope) -> Option<Self>;
17}
18
19pub trait RequestMessage: EnvelopedMessage {
20 type Response: EnvelopedMessage;
21}
22
/// Implements [`EnvelopedMessage`] for the message type `$name`.
///
/// `$name` must be both the name of the message struct and the name of the
/// corresponding `envelope::Payload` variant.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            const NAME: &'static str = std::stringify!($name);

            fn into_envelope(
                self,
                id: u32,
                responding_to: Option<u32>,
                original_sender_id: Option<u32>,
            ) -> Envelope {
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    id,
                    responding_to,
                    original_sender_id,
                    payload,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                // Any other variant (or an absent payload) is not ours.
                match envelope.payload {
                    Some(envelope::Payload::$name(inner)) => Some(inner),
                    _ => None,
                }
            }
        }
    };
}
56
/// Registers a request/response pair.
///
/// Both `$req` and `$resp` get an [`EnvelopedMessage`] impl via `message!`,
/// and `$req` additionally implements [`RequestMessage`] with `$resp` as its
/// `Response` type.
macro_rules! request_message {
    ($req:ident, $resp:ident) => {
        impl RequestMessage for $req {
            type Response = $resp;
        }

        message!($req);
        message!($resp);
    };
}
66
67request_message!(Auth, AuthResponse);
68request_message!(ShareWorktree, ShareWorktreeResponse);
69request_message!(OpenWorktree, OpenWorktreeResponse);
70request_message!(OpenFile, OpenFileResponse);
71message!(CloseFile);
72request_message!(OpenBuffer, OpenBufferResponse);
73
/// A stream of length-delimited protobuf [`Envelope`] messages layered on top
/// of a byte stream `T`.
pub struct MessageStream<T> {
    /// The underlying transport that frames are written to and read from.
    byte_stream: T,
    /// Scratch space reused between calls to hold one encoded frame.
    buffer: Vec<u8>,
}
79
80impl<T> MessageStream<T> {
81 pub fn new(byte_stream: T) -> Self {
82 Self {
83 byte_stream,
84 buffer: Default::default(),
85 }
86 }
87
88 pub fn inner_mut(&mut self) -> &mut T {
89 &mut self.byte_stream
90 }
91}
92
93impl<T> MessageStream<T>
94where
95 T: AsyncWrite + Unpin,
96{
97 /// Write a given protobuf message to the stream.
98 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
99 let message_len: u32 = message
100 .encoded_len()
101 .try_into()
102 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
103 self.buffer.clear();
104 self.buffer.extend_from_slice(&message_len.to_be_bytes());
105 message.encode(&mut self.buffer)?;
106 self.byte_stream.write_all(&self.buffer).await
107 }
108}
109
110impl<T> MessageStream<T>
111where
112 T: AsyncRead + Unpin,
113{
114 /// Read a protobuf message of the given type from the stream.
115 pub async fn read_message(&mut self) -> io::Result<Envelope> {
116 let mut delimiter_buf = [0; 4];
117 self.byte_stream.read_exact(&mut delimiter_buf).await?;
118 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
119 self.buffer.resize(message_len, 0);
120 self.byte_stream.read_exact(&mut self.buffer).await?;
121 Ok(Envelope::decode(self.buffer.as_slice())?)
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Two envelopes written back-to-back must be read back unchanged, even
    /// when the transport moves at most three bytes per call.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let transport = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };
            let mut stream = MessageStream::new(transport);

            let auth_envelope = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None, None);
            let open_buffer_envelope = OpenBuffer {
                worktree_id: 0,
                id: 1,
            }
            .into_envelope(5, None, None);

            stream.write_message(&auth_envelope).await.unwrap();
            stream.write_message(&open_buffer_envelope).await.unwrap();

            let first_decoded = stream.read_message().await.unwrap();
            let second_decoded = stream.read_message().await.unwrap();

            assert_eq!(first_decoded, auth_envelope);
            assert_eq!(second_decoded, open_buffer_envelope);
        });
    }

    /// An in-memory stream that transfers at most `chunk_size` bytes per
    /// read or write call.
    struct ChunkedStream {
        /// Everything written so far.
        bytes: Vec<u8>,
        /// Index into `bytes` of the next byte to hand out on read.
        read_offset: usize,
        /// Upper bound on the bytes moved by one `poll_read`/`poll_write`.
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let accepted = buf.len().min(this.chunk_size);
            this.bytes.extend_from_slice(&buf[..accepted]);
            Poll::Ready(Ok(accepted))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let unread = &this.bytes[this.read_offset..];
            let served = buf.len().min(this.chunk_size).min(unread.len());
            buf[..served].copy_from_slice(&unread[..served]);
            this.read_offset += served;
            Poll::Ready(Ok(served))
        }
    }
}