1use futures_io::{AsyncRead, AsyncWrite};
2use futures_lite::{AsyncReadExt, AsyncWriteExt as _};
3use prost::Message;
4use std::{convert::TryInto, io};
5
6include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
7
8pub trait EnvelopedMessage: Sized + Send + 'static {
9 fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
10 fn from_envelope(envelope: Envelope) -> Option<Self>;
11}
12
13pub trait RequestMessage: EnvelopedMessage {
14 type Response: EnvelopedMessage;
15}
16
/// Implements [`EnvelopedMessage`] for a generated message type whose
/// `envelope::Payload` variant has the same name as the type itself.
macro_rules! message {
    ($message:ident) => {
        impl EnvelopedMessage for $message {
            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                let payload = envelope::Payload::$message(self);
                Envelope {
                    id,
                    responding_to,
                    payload: Some(payload),
                }
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                // Any other variant (or no payload at all) is not ours.
                match envelope.payload {
                    Some(envelope::Payload::$message(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
38
/// Implements [`EnvelopedMessage`] for both the request and response types,
/// and declares the response as the request's [`RequestMessage::Response`].
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }
        message!($request);
        message!($response);
    };
}
48
49request_message!(Auth, AuthResponse);
50request_message!(ShareWorktree, ShareWorktreeResponse);
51request_message!(OpenWorktree, OpenWorktreeResponse);
52request_message!(OpenBuffer, OpenBufferResponse);
53
/// A stream of protobuf messages layered on top of a raw byte stream.
pub struct MessageStream<T> {
    /// The wrapped transport that message bytes are read from / written to.
    byte_stream: T,
    /// Scratch space reused between calls rather than allocated per message.
    buffer: Vec<u8>,
}

impl<T> MessageStream<T> {
    /// Wraps `byte_stream`, starting out with an empty scratch buffer.
    pub fn new(byte_stream: T) -> Self {
        let buffer = Vec::new();
        MessageStream {
            byte_stream,
            buffer,
        }
    }

    /// Mutably borrows the wrapped byte stream.
    pub fn inner_mut(&mut self) -> &mut T {
        let Self { byte_stream, .. } = self;
        byte_stream
    }
}
72
73impl<T> MessageStream<T>
74where
75 T: AsyncWrite + Unpin,
76{
77 /// Write a given protobuf message to the stream.
78 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
79 let message_len: u32 = message
80 .encoded_len()
81 .try_into()
82 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
83 self.buffer.clear();
84 self.buffer.extend_from_slice(&message_len.to_be_bytes());
85 message.encode(&mut self.buffer)?;
86 self.byte_stream.write_all(&self.buffer).await
87 }
88}
89
90impl<T> MessageStream<T>
91where
92 T: AsyncRead + Unpin,
93{
94 /// Read a protobuf message of the given type from the stream.
95 pub async fn read_message(&mut self) -> futures_io::Result<Envelope> {
96 let mut delimiter_buf = [0; 4];
97 self.byte_stream.read_exact(&mut delimiter_buf).await?;
98 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
99 self.buffer.resize(message_len, 0);
100 self.byte_stream.read_exact(&mut self.buffer).await?;
101 Ok(Envelope::decode(self.buffer.as_slice())?)
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// A stream with no bytes at all fails while reading the length prefix.
    #[test]
    fn test_read_message_from_empty_stream() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };
            let mut message_stream = MessageStream::new(byte_stream);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// A length prefix that promises more bytes than the stream holds must be
    /// reported as an unexpected EOF rather than returning a partial message.
    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            let mut bytes = 100u32.to_be_bytes().to_vec();
            bytes.extend_from_slice(b"abc");
            let byte_stream = ChunkedStream {
                bytes,
                read_offset: 0,
                chunk_size: 3,
            };
            let mut message_stream = MessageStream::new(byte_stream);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// A zero-length frame decodes to the default (empty) envelope.
    #[test]
    fn test_read_empty_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: 0u32.to_be_bytes().to_vec(),
                read_offset: 0,
                chunk_size: 3,
            };
            let mut message_stream = MessageStream::new(byte_stream);
            let decoded = message_stream.read_message().await.unwrap();
            assert_eq!(decoded, Envelope::default());
        });
    }

    /// In-memory stream that reads and writes at most `chunk_size` bytes per
    /// poll, to exercise the framing code's handling of partial I/O.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}