1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::{convert::TryInto, io};
5
// Pull in the generated protobuf message definitions from
// `$OUT_DIR/zed.messages.rs`. `Envelope`, the `envelope::Payload` enum and the
// message structs referenced below (`Auth`, `ShareWorktree`, `Worktree`, ...)
// are not defined in this file and presumably come from this include
// (generated at build time — TODO confirm against build.rs).
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
8pub trait EnvelopedMessage: Sized + Send + 'static {
9 fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
10 fn matches_envelope(envelope: &Envelope) -> bool;
11 fn from_envelope(envelope: Envelope) -> Option<Self>;
12}
13
14pub trait RequestMessage: EnvelopedMessage {
15 type Response: EnvelopedMessage;
16}
17
/// Implements [`EnvelopedMessage`] for `$name` by mapping it to the
/// same-named variant of `envelope::Payload`.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                let payload = envelope::Payload::$name(self);
                Envelope {
                    payload: Some(payload),
                    id,
                    responding_to,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                // Wildcard pattern: the payload is inspected in place, never moved.
                matches!(envelope.payload, Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
43
/// Pairs a request type with its response type: implements
/// [`RequestMessage`] for `$req` and [`EnvelopedMessage`] (via `message!`)
/// for both `$req` and `$resp`.
macro_rules! request_message {
    ($req:ident, $resp:ident) => {
        impl RequestMessage for $req {
            type Response = $resp;
        }

        message!($req);
        message!($resp);
    };
}
53
54request_message!(Auth, AuthResponse);
55request_message!(ShareWorktree, ShareWorktreeResponse);
56request_message!(OpenWorktree, OpenWorktreeResponse);
57request_message!(OpenBuffer, OpenBufferResponse);
58
/// Length-delimited framing of protobuf [`Envelope`]s over a byte stream.
///
/// Each frame is a 4-byte big-endian length followed by the encoded message.
/// `buffer` is scratch space reused by every read and write.
pub struct MessageStream<T> {
    buffer: Vec<u8>,
    byte_stream: T,
}
64
65impl<T> MessageStream<T> {
66 pub fn new(byte_stream: T) -> Self {
67 Self {
68 byte_stream,
69 buffer: Default::default(),
70 }
71 }
72
73 pub fn inner_mut(&mut self) -> &mut T {
74 &mut self.byte_stream
75 }
76}
77
78impl<T> MessageStream<T>
79where
80 T: AsyncWrite + Unpin,
81{
82 /// Write a given protobuf message to the stream.
83 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
84 let message_len: u32 = message
85 .encoded_len()
86 .try_into()
87 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
88 self.buffer.clear();
89 self.buffer.extend_from_slice(&message_len.to_be_bytes());
90 message.encode(&mut self.buffer)?;
91 self.byte_stream.write_all(&self.buffer).await
92 }
93}
94
95impl<T> MessageStream<T>
96where
97 T: AsyncRead + Unpin,
98{
99 /// Read a protobuf message of the given type from the stream.
100 pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
101 let mut delimiter_buf = [0; 4];
102 self.byte_stream.read_exact(&mut delimiter_buf).await?;
103 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
104 self.buffer.resize(message_len, 0);
105 self.byte_stream.read_exact(&mut self.buffer).await?;
106 Ok(Envelope::decode(self.buffer.as_slice())?)
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Writes two different envelopes and reads them back through a transport
    /// that moves at most three bytes per call, so every frame is split across
    /// many partial writes and reads.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let auth_envelope = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let share_envelope = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            }
            .into_envelope(5, None);

            let mut stream = MessageStream::new(ChunkedStream::new(3));

            stream.write_message(&auth_envelope).await.unwrap();
            stream.write_message(&share_envelope).await.unwrap();

            let first = stream.read_message().await.unwrap();
            let second = stream.read_message().await.unwrap();
            assert_eq!(first, auth_envelope);
            assert_eq!(second, share_envelope);
        });
    }

    /// In-memory transport that accepts and yields at most `chunk_size` bytes
    /// per poll. Writes append to `bytes`; reads consume from `read_offset`
    /// onward and return `Ok(0)` once everything written has been read.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        fn new(chunk_size: usize) -> Self {
            Self {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let accepted = &buf[..buf.len().min(this.chunk_size)];
            this.bytes.extend_from_slice(accepted);
            Poll::Ready(Ok(accepted.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let this = self.get_mut();
            let unread = &this.bytes[this.read_offset..];
            let count = unread.len().min(buf.len()).min(this.chunk_size);
            buf[..count].copy_from_slice(&unread[..count]);
            this.read_offset += count;
            Poll::Ready(Ok(count))
        }
    }
}