1use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt as _};
2use prost::Message;
3use std::{convert::TryInto, io};
4
5include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
6
7pub trait EnvelopedMessage: Sized + Send + 'static {
8 const NAME: &'static str;
9 fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope;
10 fn matches_envelope(envelope: &Envelope) -> bool;
11 fn from_envelope(envelope: Envelope) -> Option<Self>;
12}
13
14pub trait RequestMessage: EnvelopedMessage {
15 type Response: EnvelopedMessage;
16}
17
/// Implements [`EnvelopedMessage`] for the protobuf type `$name`, which must have
/// a same-named variant in `envelope::Payload`.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            const NAME: &'static str = std::stringify!($name);

            fn into_envelope(self, id: u32, responding_to: Option<u32>) -> Envelope {
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    id,
                    responding_to,
                    payload,
                }
            }

            fn matches_envelope(candidate: &Envelope) -> bool {
                // Only inspects the variant; nothing is moved out of `candidate`.
                matches!(candidate.payload, Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
45
/// Implements [`EnvelopedMessage`] for both a request type and its response type,
/// and links them through [`RequestMessage`].
macro_rules! request_message {
    ($request:ident, $response:ident) => {
        impl RequestMessage for $request {
            type Response = $response;
        }
        message!($request);
        message!($response);
    };
}
55
56request_message!(Auth, AuthResponse);
57request_message!(ShareWorktree, ShareWorktreeResponse);
58request_message!(OpenWorktree, OpenWorktreeResponse);
59request_message!(OpenFile, OpenFileResponse);
60message!(CloseFile);
61request_message!(OpenBuffer, OpenBufferResponse);
62
/// A stream of protobuf messages.
///
/// Wraps a byte stream `T` and frames every [`Envelope`] as a 4-byte big-endian
/// length prefix followed by the encoded message bytes.
pub struct MessageStream<T> {
    /// The underlying transport that frames are read from and written to.
    byte_stream: T,
    /// Scratch space shared by reads and writes, so a single allocation is reused
    /// across messages.
    buffer: Vec<u8>,
}
68
69impl<T> MessageStream<T> {
70 pub fn new(byte_stream: T) -> Self {
71 Self {
72 byte_stream,
73 buffer: Default::default(),
74 }
75 }
76
77 pub fn inner_mut(&mut self) -> &mut T {
78 &mut self.byte_stream
79 }
80}
81
82impl<T> MessageStream<T>
83where
84 T: AsyncWrite + Unpin,
85{
86 /// Write a given protobuf message to the stream.
87 pub async fn write_message(&mut self, message: &Envelope) -> io::Result<()> {
88 let message_len: u32 = message
89 .encoded_len()
90 .try_into()
91 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "message is too large"))?;
92 self.buffer.clear();
93 self.buffer.extend_from_slice(&message_len.to_be_bytes());
94 message.encode(&mut self.buffer)?;
95 self.byte_stream.write_all(&self.buffer).await
96 }
97}
98
99impl<T> MessageStream<T>
100where
101 T: AsyncRead + Unpin,
102{
103 /// Read a protobuf message of the given type from the stream.
104 pub async fn read_message(&mut self) -> io::Result<Envelope> {
105 let mut delimiter_buf = [0; 4];
106 self.byte_stream.read_exact(&mut delimiter_buf).await?;
107 let message_len = u32::from_be_bytes(delimiter_buf) as usize;
108 self.buffer.resize(message_len, 0);
109 self.byte_stream.read_exact(&mut self.buffer).await?;
110 Ok(Envelope::decode(self.buffer.as_slice())?)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let byte_stream = ChunkedStream {
                bytes: Vec::new(),
                read_offset: 0,
                chunk_size: 3,
            };

            let message1 = Auth {
                user_id: 5,
                access_token: "the-access-token".into(),
            }
            .into_envelope(3, None);

            let message2 = OpenBuffer {
                worktree_id: 1,
                path: "path".to_string(),
            }
            .into_envelope(5, None);

            let mut message_stream = MessageStream::new(byte_stream);
            message_stream.write_message(&message1).await.unwrap();
            message_stream.write_message(&message2).await.unwrap();
            let decoded_message1 = message_stream.read_message().await.unwrap();
            let decoded_message2 = message_stream.read_message().await.unwrap();
            assert_eq!(decoded_message1, message1);
            assert_eq!(decoded_message2, message2);
        });
    }

    #[test]
    fn test_read_message_from_truncated_stream() {
        smol::block_on(async {
            // The stream ends before the 4-byte length prefix is complete.
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: vec![0, 0],
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);

            // The prefix announces 10 payload bytes, but only 3 follow.
            let mut message_stream = MessageStream::new(ChunkedStream {
                bytes: vec![0, 0, 0, 10, 1, 2, 3],
                read_offset: 0,
                chunk_size: 3,
            });
            let error = message_stream.read_message().await.unwrap_err();
            assert_eq!(error.kind(), io::ErrorKind::UnexpectedEof);
        });
    }

    /// In-memory stream that accepts and yields at most `chunk_size` bytes per
    /// poll, forcing readers and writers to handle partial transfers.
    struct ChunkedStream {
        /// Everything written so far; reads consume it from `read_offset` onward.
        bytes: Vec<u8>,
        read_offset: usize,
        chunk_size: usize,
    }

    impl AsyncWrite for ChunkedStream {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let bytes_written = buf.len().min(self.chunk_size);
            self.bytes.extend_from_slice(&buf[0..bytes_written]);
            Poll::Ready(Ok(bytes_written))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    impl AsyncRead for ChunkedStream {
        fn poll_read(
            mut self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // Returns Ok(0) (end of stream) once every written byte has been read.
            let bytes_read = buf
                .len()
                .min(self.chunk_size)
                .min(self.bytes.len() - self.read_offset);
            let end_offset = self.read_offset + bytes_read;
            buf[0..bytes_read].copy_from_slice(&self.bytes[self.read_offset..end_offset]);
            self.read_offset = end_offset;
            Poll::Ready(Ok(bytes_read))
        }
    }
}