1use super::{ConnectionId, PeerId, TypedEnvelope};
2use anyhow::{anyhow, Result};
3use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
4use futures::{SinkExt as _, StreamExt as _};
5use prost::Message;
6use std::any::Any;
7use std::sync::Arc;
8use std::{
9 io,
10 time::{Duration, SystemTime, UNIX_EPOCH},
11};
12
// Pull in the protobuf types generated into OUT_DIR at build time (presumably by a
// build script using prost-build — confirm): `Envelope`, the `envelope::Payload` enum,
// `Timestamp`, and the individual message structs referenced below.
include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
14
15pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
16 const NAME: &'static str;
17 fn into_envelope(
18 self,
19 id: u32,
20 responding_to: Option<u32>,
21 original_sender_id: Option<u32>,
22 ) -> Envelope;
23 fn matches_envelope(envelope: &Envelope) -> bool;
24 fn from_envelope(envelope: Envelope) -> Option<Self>;
25}
26
27pub trait RequestMessage: EnvelopedMessage {
28 type Response: EnvelopedMessage;
29}
30
/// Declares the one-way (non request/response) message types.
///
/// Generates `unicast_message_into_typed_envelope` for the listed types and
/// invokes `message!` for each of them. A trailing comma is accepted.
macro_rules! messages {
    ($($name:ident),* $(,)?) => {
        /// Converts `envelope` into a type-erased `TypedEnvelope<M>`, where `M` is
        /// the concrete message type held in its payload, if that payload is one
        /// of the one-way messages passed to `messages!`.
        ///
        /// On a match the payload is moved out of `envelope` (leaving `None`).
        /// Otherwise `envelope` is left exactly as it was and `None` is returned,
        /// so the caller can try another decoder on the same envelope.
        fn unicast_message_into_typed_envelope(
            sender_id: ConnectionId,
            envelope: &mut Envelope,
        ) -> Option<Arc<dyn Any + Send + Sync>> {
            // Take the payload so each arm can bind the *inner* message by value;
            // storing the whole `Payload` enum would make the `TypedEnvelope`
            // impossible to downcast to `TypedEnvelope<M>`.
            match envelope.payload.take() {
                $(
                    Some(envelope::Payload::$name(payload)) => Some(Arc::new(TypedEnvelope {
                        sender_id,
                        original_sender_id: envelope.original_sender_id.map(PeerId),
                        message_id: envelope.id,
                        payload,
                    })),
                )*
                other => {
                    // Not a one-way message (or no payload): restore it untouched.
                    envelope.payload = other;
                    None
                }
            }
        }

        $(
            message!($name);
        )*
    };
}
50
/// Declares request/response message pairs.
///
/// Generates `request_message_into_typed_envelope` for every listed type,
/// invokes `message!` for both sides of each pair, and ties each request to its
/// response through a `RequestMessage` impl.
macro_rules! request_messages {
    ($(($request_name:ident, $response_name:ident)),*) => {
        /// Converts `envelope` into a type-erased `TypedEnvelope<M>` when its
        /// payload is one of the requests or responses passed to
        /// `request_messages!`; returns `None` for a missing or unknown payload.
        fn request_message_into_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Arc<dyn Any + Send + Sync>> {
            let message_id = envelope.id;
            let original_sender_id = envelope.original_sender_id.map(PeerId);
            let typed_envelope: Arc<dyn Any + Send + Sync> = match envelope.payload? {
                $(
                    envelope::Payload::$request_name(payload) => Arc::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    }),
                    envelope::Payload::$response_name(payload) => Arc::new(TypedEnvelope {
                        sender_id,
                        original_sender_id,
                        message_id,
                        payload,
                    }),
                )*
                _ => return None,
            };
            Some(typed_envelope)
        }

        $(
            message!($request_name);
            message!($response_name);

            impl RequestMessage for $request_name {
                type Response = $response_name;
            }
        )*
    };
}
83
/// Implements `EnvelopedMessage` for a message type whose variant in
/// `envelope::Payload` has the same name as the type.
macro_rules! message {
    ($name:ident) => {
        impl EnvelopedMessage for $name {
            const NAME: &'static str = std::stringify!($name);

            fn into_envelope(
                self,
                id: u32,
                responding_to: Option<u32>,
                original_sender_id: Option<u32>,
            ) -> Envelope {
                let payload = Some(envelope::Payload::$name(self));
                Envelope {
                    payload,
                    id,
                    responding_to,
                    original_sender_id,
                }
            }

            fn matches_envelope(envelope: &Envelope) -> bool {
                // Only inspects the variant; nothing is moved out of the borrow.
                matches!(envelope.payload, Some(envelope::Payload::$name(_)))
            }

            fn from_envelope(envelope: Envelope) -> Option<Self> {
                match envelope.payload {
                    Some(envelope::Payload::$name(message)) => Some(message),
                    _ => None,
                }
            }
        }
    };
}
117
// One-way messages: each type gets an `EnvelopedMessage` impl via `message!` and is
// recognized by the generated `unicast_message_into_typed_envelope`.
messages!(
    UpdateWorktree,
    CloseWorktree,
    CloseBuffer,
    UpdateBuffer,
    AddPeer,
    RemovePeer,
    SendChannelMessage,
    ChannelMessageSent
);
128
// Request/response pairs, written as `(Request, Response)`. Both sides get an
// `EnvelopedMessage` impl, each request gets a `RequestMessage` impl naming its
// response type, and both are recognized by `request_message_into_typed_envelope`.
request_messages!(
    (Auth, AuthResponse),
    (ShareWorktree, ShareWorktreeResponse),
    (OpenWorktree, OpenWorktreeResponse),
    (OpenBuffer, OpenBufferResponse),
    (SaveBuffer, BufferSaved),
    (GetChannels, GetChannelsResponse),
    (JoinChannel, JoinChannelResponse),
    (GetUsers, GetUsersResponse)
);
139
140pub fn build_typed_envelope(
141 sender_id: ConnectionId,
142 mut envelope: Envelope,
143) -> Result<Arc<dyn Any + Send + Sync>> {
144 unicast_message_into_typed_envelope(sender_id, &mut envelope)
145 .or_else(|| request_message_into_typed_envelope(sender_id, envelope))
146 .ok_or_else(|| anyhow!("unrecognized payload type"))
147}
148
/// A stream of protobuf messages: wraps an underlying WebSocket stream and/or
/// sink and exchanges `Envelope`s over it.
pub struct MessageStream<S> {
    stream: S,
}

impl<S> MessageStream<S> {
    /// Wraps `stream` in a new `MessageStream`.
    pub fn new(stream: S) -> Self {
        MessageStream { stream }
    }

    /// Gives mutable access to the wrapped stream.
    pub fn inner_mut(&mut self) -> &mut S {
        let Self { stream } = self;
        stream
    }
}
163
164impl<S> MessageStream<S>
165where
166 S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
167{
168 /// Write a given protobuf message to the stream.
169 pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
170 let mut buffer = Vec::with_capacity(message.encoded_len());
171 message
172 .encode(&mut buffer)
173 .map_err(|err| io::Error::from(err))?;
174 self.stream.send(WebSocketMessage::Binary(buffer)).await?;
175 Ok(())
176 }
177}
178
179impl<S> MessageStream<S>
180where
181 S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
182{
183 /// Read a protobuf message of the given type from the stream.
184 pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
185 while let Some(bytes) = self.stream.next().await {
186 match bytes? {
187 WebSocketMessage::Binary(bytes) => {
188 let envelope = Envelope::decode(bytes.as_slice()).map_err(io::Error::from)?;
189 return Ok(envelope);
190 }
191 WebSocketMessage::Close(_) => break,
192 _ => {}
193 }
194 }
195 Err(WebSocketError::ConnectionClosed)
196 }
197}
198
199impl Into<SystemTime> for Timestamp {
200 fn into(self) -> SystemTime {
201 UNIX_EPOCH
202 .checked_add(Duration::new(self.seconds, self.nanos))
203 .unwrap()
204 }
205}
206
207impl From<SystemTime> for Timestamp {
208 fn from(time: SystemTime) -> Self {
209 let duration = time.duration_since(UNIX_EPOCH).unwrap();
210 Self {
211 seconds: duration.as_secs(),
212 nanos: duration.subsec_nanos(),
213 }
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Messages written to a `MessageStream` come back out, in order, unchanged.
    #[test]
    fn test_round_trip_message() {
        smol::block_on(async {
            let sent = vec![
                Auth {
                    user_id: 5,
                    access_token: "the-access-token".into(),
                }
                .into_envelope(3, None, None),
                OpenBuffer {
                    worktree_id: 0,
                    path: "some/path".to_string(),
                }
                .into_envelope(5, None, None),
            ];

            let mut message_stream = MessageStream::new(test::Channel::new());
            for envelope in &sent {
                message_stream.write_message(envelope).await.unwrap();
            }
            for expected in &sent {
                let decoded = message_stream.read_message().await.unwrap();
                assert_eq!(&decoded, expected);
            }
        });
    }
}