1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
2use prost::Message;
3use std::{convert::TryInto, io};
4
5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
6
7pub trait EnvelopedMessage: Sized + Send + 'static {
8 fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
9 fn matches_envelope(envelope: &Envelope) -> bool;
10 fn from_envelope(envelope: Envelope) -> Option<Self>;
11}
12
13pub trait RequestMessage: EnvelopedMessage {
14 type Response: EnvelopedMessage;
15}
16
/// Implements [`EnvelopedMessage`] for the message type `$name`.
///
/// Relies on `envelope::Payload` having a variant with the same name as the
/// message type, wrapping a value of that type.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    id,
                    responding_to,
                    payload,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                matches!(envelope.payload.as_ref(), Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
42
/// Declares a request/response pair.
///
/// Ties `$request` to `$response` through [`RequestMessage`] and implements
/// [`EnvelopedMessage`] for both types via `message!`.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }
        message!($request);
        message!($response);
    };
}
52
53request_message!(Auth, AuthResponse);
54request_message!(ShareWorktree, ShareWorktreeResponse);
55request_message!(OpenWorktree, OpenWorktreeResponse);
56request_message!(OpenBuffer, OpenBufferResponse);
57
/// A stream of protobuf messages framed over an underlying byte stream.
///
/// Each frame is a 4-byte big-endian length prefix followed by that many
/// bytes of protobuf-encoded `Envelope`.
pub struct MessageStream<T> {
    /// The transport that frames are written to and read from.
    byte_stream: T,
    /// Scratch space reused across calls, both to assemble outgoing frames
    /// and to hold incoming message bodies.
    buffer: Vec<u8>,
}

impl<T> MessageStream<T> {
    /// Wraps `byte_stream`, starting with an empty scratch buffer.
    pub fn new(byte_stream: T) -> Self {
        MessageStream {
            buffer: Vec::new(),
            byte_stream,
        }
    }

    /// Returns mutable access to the wrapped byte stream.
    pub fn inner_mut(&mut self) -> &mut T {
        let Self { byte_stream, .. } = self;
        byte_stream
    }
}
76
77impl<T> MessageStream<T>
78where
79 T: AsyncWrite + Unpin,
80{
81 /// Write a given protobuf message to the stream.
82 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
83 let message_len: u32 = message
84 .encoded_len()
85 .try_into()
86 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
87 self.buffer.clear();
88 self.buffer.extend_from_slice(&message_len.to_be_bytes());
89 message.encode(&mut self.buffer)?;
90 self.byte_stream.write_all(&self.buffer).await
91 }
92}
93
94impl<T> MessageStream<T>
95where
96 T: AsyncRead + Unpin,
97{
98 /// Read a protobuf message of the given type from the stream.
99 pub async fn read_message(&mut self) -> io::Result<Envelope> {
100 let mut delimiter_buf = [0; 4];
101 self.byte_stream.read_exact(&mut delimiter_buf).await?;
102 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
103 self.buffer.resize(message_len, 0);
104 self.byte_stream.read_exact(&mut self.buffer).await?;
105 Ok(Envelope::decode(self.buffer.as_slice())?)
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = ShareWorktree {
                worktree: Some(Worktree {
                    paths: vec![b"ok".to_vec()],
                }),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            // The header announces a 10-byte body, but only 3 bytes follow.
            let mut bytes = 10u32.to_be_bytes().to_vec();
            bytes.extend_from_slice(&[1, 2, 3]);
            let byte_stream = ChunkedStream {
                bytes,
                read_offset: 0,
                chunk_size: 3,
            };

            let mut message_stream = MessageStream::new(byte_stream);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    #[test]
    fn test_read_from_empty_stream() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let mut message_stream = MessageStream::new(byte_stream);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory byte stream that transfers at most `chunk_size` bytes per
    /// read or write call. This exercises the partial-read and partial-write
    /// paths of `MessageStream`. Writes append to `bytes`; reads consume
    /// `bytes` starting at `read_offset`.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) (EOF) once every written byte has been consumed.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}