1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
2use prost::Message;
3use std::{convert::TryInto, io};
4
5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
6
7pub trait EnvelopedMessage: Sized + Send + 'static {
8 const NAME: &'static str;
9 fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
10 fn matches_envelope(envelope: &Envelope) -> bool;
11 fn from_envelope(envelope: Envelope) -> Option<Self>;
12}
13
14pub trait RequestMessage: EnvelopedMessage {
15 type Response: EnvelopedMessage;
16}
17
/// Implements [`EnvelopedMessage`] for a generated protobuf type whose
/// `envelope::Payload` variant has the same name as the type itself.
macro_rules! message {
    ($variant:ident) => {
        impl EnvelopedMessage for $variant {
            const NAME: &'static str = ::std::stringify!($variant);

            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                let payload = envelope::Payload::$variant(self);
                Envelope {
                    id,
                    responding_to,
                    payload: Some(payload),
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                matches!(
                    envelope.payload.as_ref(),
                    Some(envelope::Payload::$variant(_))
                )
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$variant(inner)) => Some(inner),
                    _ => None,
                }
            }
        }
    };
}
45
/// Registers a request type and its response type.
///
/// Both types get an [`EnvelopedMessage`] impl via `message!`, and the request type's
/// [`RequestMessage::Response`] is set to the response type.
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }
        message!($request);
        message!($response);
    };
}
55
56request_message!(Auth, AuthResponse);
57request_message!(ShareWorktree, ShareWorktreeResponse);
58request_message!(OpenWorktree, OpenWorktreeResponse);
59request_message!(OpenBuffer, OpenBufferResponse);
60
/// A stream of protobuf [`Envelope`] messages layered over a byte stream.
///
/// Every frame is a 4-byte big-endian length prefix followed by that many bytes of
/// encoded message.
pub struct MessageStream<T> {
    /// The underlying transport that frames are written to and read from.
    byte_stream: T,
    /// Scratch space that is cleared and reused for each frame instead of reallocated.
    buffer: Vec<u8>,
}
66
67impl<T> MessageStream<T> {
68 pub fn new(byte_stream: T) -> Self {
69 Self {
70 byte_stream,
71 buffer: Default::default(),
72 }
73 }
74
75 pub fn inner_mut(&mut self) -> &mut T {
76 &mut self.byte_stream
77 }
78}
79
80impl<T> MessageStream<T>
81where
82 T: AsyncWrite + Unpin,
83{
84 /// Write a given protobuf message to the stream.
85 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
86 let message_len: u32 = message
87 .encoded_len()
88 .try_into()
89 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
90 self.buffer.clear();
91 self.buffer.extend_from_slice(&message_len.to_be_bytes());
92 message.encode(&mut self.buffer)?;
93 self.byte_stream.write_all(&self.buffer).await
94 }
95}
96
97impl<T> MessageStream<T>
98where
99 T: AsyncRead + Unpin,
100{
101 /// Read a protobuf message of the given type from the stream.
102 pub async fn read_message(&mut self) -> io::Result<Envelope> {
103 let mut delimiter_buf = [0; 4];
104 self.byte_stream.read_exact(&mut delimiter_buf).await?;
105 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
106 self.buffer.resize(message_len, 0);
107 self.byte_stream.read_exact(&mut self.buffer).await?;
108 Ok(Envelope::decode(self.buffer.as_slice())?)
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    /// Two different envelopes survive a write/read round trip through a stream that
    /// transfers at most 3 bytes per poll.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = OpenBuffer {
                worktree_id: 1,
                path: "path".to_string(),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(ChunkedStream::new(3));
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    /// A stream that ends partway through a frame yields `UnexpectedEof` instead of a
    /// partial or garbage message.
    #[test]
    fn test_read_truncated_message() {
        smol::block_on(async {
            // The stream ends inside the 4-byte length prefix.
            let mut message_stream = MessageStream::new(ChunkedStream::new(3));
            message_stream.inner_mut().bytes.extend_from_slice(&[0, 0]);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);

            // The stream ends inside the message body.
            let message = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);
            let mut message_stream = MessageStream::new(ChunkedStream::new(3));
            message_stream.write_message(&message).await.unwrap();
            message_stream.inner_mut().bytes.pop();
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);

            // The length prefix claims more bytes than the stream will ever deliver.
            let mut message_stream = MessageStream::new(ChunkedStream::new(3));
            let bytes = &mut message_stream.inner_mut().bytes;
            bytes.extend_from_slice(&1000u32.to_be_bytes());
            bytes.extend_from_slice(&[0; 10]);
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory byte stream that moves at most `chunk_size` bytes per read or write call,
    /// so that callers' handling of partial reads and writes is exercised.
    struct ChunkedStream {
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl ChunkedStream {
        /// Creates an empty stream with the given per-call transfer limit.
        fn new(chunk_size: usize) -> Self {
            Self {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size,
            }
        }
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) (end of stream) once every written byte has been read.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}