1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 fmt,
16 future::Future,
17 pin::Pin,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22};
23
24type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
25type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
26
/// Identifies one connection registered with a [`Peer`].
///
/// Ids come from a counter owned by each `Peer`, so ids handed out by different
/// peers can coincide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u32);
29
30struct Connection {
31 writer: Mutex<MessageStream<BoxedWriter>>,
32 reader: Mutex<MessageStream<BoxedReader>>,
33 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
34 next_message_id: AtomicU32,
35}
36
37type MessageHandler = Box<
38 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
39>;
40
41pub struct TypedEnvelope<T> {
42 pub id: u32,
43 pub connection_id: ConnectionId,
44 pub payload: T,
45}
46
47pub struct Peer {
48 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
49 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
50 message_handlers: RwLock<Vec<MessageHandler>>,
51 handler_types: Mutex<HashSet<TypeId>>,
52 next_connection_id: AtomicU32,
53}
54
55impl Peer {
56 pub fn new() -> Arc<Self> {
57 Arc::new(Self {
58 connections: Default::default(),
59 connection_close_barriers: Default::default(),
60 message_handlers: Default::default(),
61 handler_types: Default::default(),
62 next_connection_id: Default::default(),
63 })
64 }
65
66 pub async fn add_message_handler<T: EnvelopedMessage>(
67 &self,
68 ) -> mpsc::Receiver<TypedEnvelope<T>> {
69 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
70 panic!("duplicate handler type");
71 }
72
73 let (tx, rx) = mpsc::channel(256);
74 self.message_handlers
75 .write()
76 .await
77 .push(Box::new(move |envelope, connection_id| {
78 if envelope.as_ref().map_or(false, T::matches_envelope) {
79 let envelope = Option::take(envelope).unwrap();
80 let mut tx = tx.clone();
81 Some(
82 async move {
83 tx.send(TypedEnvelope {
84 id: envelope.id,
85 connection_id,
86 payload: T::from_envelope(envelope).unwrap(),
87 })
88 .await
89 .is_err()
90 }
91 .boxed(),
92 )
93 } else {
94 None
95 }
96 }));
97 rx
98 }
99
100 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
101 where
102 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
103 {
104 let connection_id = ConnectionId(
105 self.next_connection_id
106 .fetch_add(1, atomic::Ordering::SeqCst),
107 );
108 self.connections.write().await.insert(
109 connection_id,
110 Arc::new(Connection {
111 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
112 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
113 response_channels: Default::default(),
114 next_message_id: Default::default(),
115 }),
116 );
117 connection_id
118 }
119
120 pub async fn disconnect(&self, connection_id: ConnectionId) {
121 self.connections.write().await.remove(&connection_id);
122 self.connection_close_barriers
123 .write()
124 .await
125 .remove(&connection_id);
126 }
127
128 pub fn handle_messages(
129 self: &Arc<Self>,
130 connection_id: ConnectionId,
131 ) -> impl Future<Output = Result<()>> + 'static {
132 let (close_tx, mut close_rx) = barrier::channel();
133 let this = self.clone();
134 async move {
135 this.connection_close_barriers
136 .write()
137 .await
138 .insert(connection_id, close_tx);
139 let connection = this.connection(connection_id).await?;
140 let closed = close_rx.recv();
141 futures::pin_mut!(closed);
142
143 loop {
144 let mut reader = connection.reader.lock().await;
145 let read_message = reader.read_message();
146 futures::pin_mut!(read_message);
147
148 match futures::future::select(read_message, &mut closed).await {
149 Either::Left((Ok(incoming), _)) => {
150 if let Some(responding_to) = incoming.responding_to {
151 let channel = connection
152 .response_channels
153 .lock()
154 .await
155 .remove(&responding_to);
156 if let Some(mut tx) = channel {
157 tx.send(incoming).await.ok();
158 } else {
159 log::warn!(
160 "received RPC response to unknown request {}",
161 responding_to
162 );
163 }
164 } else {
165 let mut envelope = Some(incoming);
166 let mut handler_index = None;
167 let mut handler_was_dropped = false;
168 for (i, handler) in
169 this.message_handlers.read().await.iter().enumerate()
170 {
171 if let Some(future) = handler(&mut envelope, connection_id) {
172 handler_was_dropped = future.await;
173 handler_index = Some(i);
174 break;
175 }
176 }
177
178 if let Some(handler_index) = handler_index {
179 if handler_was_dropped {
180 drop(this.message_handlers.write().await.remove(handler_index));
181 }
182 } else {
183 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
184 }
185 }
186 }
187 Either::Left((Err(error), _)) => {
188 log::warn!("received invalid RPC message: {}", error);
189 Err(error)?;
190 }
191 Either::Right(_) => return Ok(()),
192 }
193 }
194 }
195 }
196
197 pub async fn receive<M: EnvelopedMessage>(
198 self: &Arc<Self>,
199 connection_id: ConnectionId,
200 ) -> Result<TypedEnvelope<M>> {
201 let connection = self.connection(connection_id).await?;
202 let envelope = connection.reader.lock().await.read_message().await?;
203 let id = envelope.id;
204 let payload =
205 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
206 Ok(TypedEnvelope {
207 id,
208 connection_id,
209 payload,
210 })
211 }
212
213 pub fn request<T: RequestMessage>(
214 self: &Arc<Self>,
215 connection_id: ConnectionId,
216 req: T,
217 ) -> impl Future<Output = Result<T::Response>> {
218 let this = self.clone();
219 let (tx, mut rx) = oneshot::channel();
220 async move {
221 let connection = this.connection(connection_id).await?;
222 let message_id = connection
223 .next_message_id
224 .fetch_add(1, atomic::Ordering::SeqCst);
225 connection
226 .response_channels
227 .lock()
228 .await
229 .insert(message_id, tx);
230 connection
231 .writer
232 .lock()
233 .await
234 .write_message(&req.into_envelope(message_id, None))
235 .await?;
236 let response = rx
237 .recv()
238 .await
239 .expect("response channel was unexpectedly dropped");
240 T::Response::from_envelope(response)
241 .ok_or_else(|| anyhow!("received response of the wrong type"))
242 }
243 }
244
245 pub fn send<T: EnvelopedMessage>(
246 self: &Arc<Self>,
247 connection_id: ConnectionId,
248 message: T,
249 ) -> impl Future<Output = Result<()>> {
250 let this = self.clone();
251 async move {
252 let connection = this.connection(connection_id).await?;
253 let message_id = connection
254 .next_message_id
255 .fetch_add(1, atomic::Ordering::SeqCst);
256 connection
257 .writer
258 .lock()
259 .await
260 .write_message(&message.into_envelope(message_id, None))
261 .await?;
262 Ok(())
263 }
264 }
265
266 pub fn respond<T: RequestMessage>(
267 self: &Arc<Self>,
268 request: TypedEnvelope<T>,
269 response: T::Response,
270 ) -> impl Future<Output = Result<()>> {
271 let this = self.clone();
272 async move {
273 let connection = this.connection(request.connection_id).await?;
274 let message_id = connection
275 .next_message_id
276 .fetch_add(1, atomic::Ordering::SeqCst);
277 connection
278 .writer
279 .lock()
280 .await
281 .write_message(&response.into_envelope(message_id, Some(request.id)))
282 .await?;
283 Ok(())
284 }
285 }
286
287 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
288 Ok(self
289 .connections
290 .read()
291 .await
292 .get(&id)
293 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
294 .clone())
295 }
296}
297
298impl fmt::Display for ConnectionId {
299 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300 self.0.fmt(f)
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients send two requests each to one server. The server answers them
    /// through its `Auth` and `OpenBuffer` handlers, and each client gets back the
    /// matching response.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: "path/two".to_string(),
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: "path/one".to_string(),
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The spawned task gets its own copies; the originals are used below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// After `disconnect`, the client's `handle_messages` loop ends. The server's
    /// end of the socket then observes a broken pipe.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on a connection whose write side is already closed fails with the
    /// underlying I/O error.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}