1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 pin::Pin,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
27
/// Identifies one transport connection registered with a `Peer`.
///
/// Values are handed out by `Peer::add_connection` from a per-peer counter.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct ConnectionId(pub u32);

/// Identifies the original sender of a forwarded message, as carried in an
/// envelope's `original_sender_id` field.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct PeerId(pub u32);
33
/// Per-connection state owned by a `Peer`.
///
/// The reader and writer each wrap their own clone of the underlying
/// connection, so reads and writes are serialized by separate locks.
struct Connection {
    writer: Mutex<MessageStream<BoxedWriter>>,
    reader: Mutex<MessageStream<BoxedReader>>,
    /// Channels for requests still awaiting an RPC response, keyed by the
    /// message id the request was sent with; an incoming envelope whose
    /// `responding_to` matches a key is delivered to that channel.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Source of ids for every outgoing message on this connection.
    next_message_id: AtomicU32,
}
40
41type MessageHandler = Box<
42 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
43>;
44
45#[derive(Clone, Copy)]
46pub struct Receipt<T> {
47 sender_id: ConnectionId,
48 message_id: u32,
49 payload_type: PhantomData<T>,
50}
51
52pub struct TypedEnvelope<T> {
53 pub sender_id: ConnectionId,
54 original_sender_id: Option<PeerId>,
55 pub message_id: u32,
56 pub payload: T,
57}
58
59impl<T> TypedEnvelope<T> {
60 pub fn original_sender_id(&self) -> Result<PeerId> {
61 self.original_sender_id
62 .ok_or_else(|| anyhow!("missing original_sender_id"))
63 }
64}
65
66impl<T: RequestMessage> TypedEnvelope<T> {
67 pub fn receipt(&self) -> Receipt<T> {
68 Receipt {
69 sender_id: self.sender_id,
70 message_id: self.message_id,
71 payload_type: PhantomData,
72 }
73 }
74}
75
76pub struct Peer {
77 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
78 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
79 message_handlers: RwLock<Vec<MessageHandler>>,
80 handler_types: Mutex<HashSet<TypeId>>,
81 next_connection_id: AtomicU32,
82}
83
84impl Peer {
85 pub fn new() -> Arc<Self> {
86 Arc::new(Self {
87 connections: Default::default(),
88 connection_close_barriers: Default::default(),
89 message_handlers: Default::default(),
90 handler_types: Default::default(),
91 next_connection_id: Default::default(),
92 })
93 }
94
95 pub async fn add_message_handler<T: EnvelopedMessage>(
96 &self,
97 ) -> mpsc::Receiver<TypedEnvelope<T>> {
98 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
99 panic!("duplicate handler type");
100 }
101
102 let (tx, rx) = mpsc::channel(256);
103 self.message_handlers
104 .write()
105 .await
106 .push(Box::new(move |envelope, connection_id| {
107 if envelope.as_ref().map_or(false, T::matches_envelope) {
108 let envelope = Option::take(envelope).unwrap();
109 let mut tx = tx.clone();
110 Some(
111 async move {
112 tx.send(TypedEnvelope {
113 sender_id: connection_id,
114 original_sender_id: envelope.original_sender_id.map(PeerId),
115 message_id: envelope.id,
116 payload: T::from_envelope(envelope).unwrap(),
117 })
118 .await
119 .is_err()
120 }
121 .boxed(),
122 )
123 } else {
124 None
125 }
126 }));
127 rx
128 }
129
130 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
131 where
132 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
133 {
134 let connection_id = ConnectionId(
135 self.next_connection_id
136 .fetch_add(1, atomic::Ordering::SeqCst),
137 );
138 self.connections.write().await.insert(
139 connection_id,
140 Arc::new(Connection {
141 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
142 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
143 response_channels: Default::default(),
144 next_message_id: Default::default(),
145 }),
146 );
147 connection_id
148 }
149
150 pub async fn disconnect(&self, connection_id: ConnectionId) {
151 self.connections.write().await.remove(&connection_id);
152 self.connection_close_barriers
153 .write()
154 .await
155 .remove(&connection_id);
156 }
157
158 pub async fn reset(&self) {
159 self.connections.write().await.clear();
160 self.connection_close_barriers.write().await.clear();
161 self.handler_types.lock().await.clear();
162 self.message_handlers.write().await.clear();
163 }
164
165 pub fn handle_messages(
166 self: &Arc<Self>,
167 connection_id: ConnectionId,
168 ) -> impl Future<Output = Result<()>> + 'static {
169 let (close_tx, mut close_rx) = barrier::channel();
170 let this = self.clone();
171 async move {
172 this.connection_close_barriers
173 .write()
174 .await
175 .insert(connection_id, close_tx);
176 let connection = this.connection(connection_id).await?;
177 let closed = close_rx.recv();
178 futures::pin_mut!(closed);
179
180 loop {
181 let mut reader = connection.reader.lock().await;
182 let read_message = reader.read_message();
183 futures::pin_mut!(read_message);
184
185 match futures::future::select(read_message, &mut closed).await {
186 Either::Left((Ok(incoming), _)) => {
187 if let Some(responding_to) = incoming.responding_to {
188 let channel = connection
189 .response_channels
190 .lock()
191 .await
192 .remove(&responding_to);
193 if let Some(mut tx) = channel {
194 tx.send(incoming).await.ok();
195 } else {
196 log::warn!(
197 "received RPC response to unknown request {}",
198 responding_to
199 );
200 }
201 } else {
202 let mut envelope = Some(incoming);
203 let mut handler_index = None;
204 let mut handler_was_dropped = false;
205 for (i, handler) in
206 this.message_handlers.read().await.iter().enumerate()
207 {
208 if let Some(future) = handler(&mut envelope, connection_id) {
209 handler_was_dropped = future.await;
210 handler_index = Some(i);
211 break;
212 }
213 }
214
215 if let Some(handler_index) = handler_index {
216 if handler_was_dropped {
217 drop(this.message_handlers.write().await.remove(handler_index));
218 }
219 } else {
220 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
221 }
222 }
223 }
224 Either::Left((Err(error), _)) => {
225 log::warn!("received invalid RPC message: {}", error);
226 Err(error)?;
227 }
228 Either::Right(_) => return Ok(()),
229 }
230 }
231 }
232 }
233
234 pub async fn receive<M: EnvelopedMessage>(
235 self: &Arc<Self>,
236 connection_id: ConnectionId,
237 ) -> Result<TypedEnvelope<M>> {
238 let connection = self.connection(connection_id).await?;
239 let envelope = connection.reader.lock().await.read_message().await?;
240 let original_sender_id = envelope.original_sender_id;
241 let message_id = envelope.id;
242 let payload =
243 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
244 Ok(TypedEnvelope {
245 sender_id: connection_id,
246 original_sender_id: original_sender_id.map(PeerId),
247 message_id,
248 payload,
249 })
250 }
251
252 pub fn request<T: RequestMessage>(
253 self: &Arc<Self>,
254 receiver_id: ConnectionId,
255 request: T,
256 ) -> impl Future<Output = Result<T::Response>> {
257 self.request_internal(None, receiver_id, request)
258 }
259
260 pub fn forward_request<T: RequestMessage>(
261 self: &Arc<Self>,
262 sender_id: ConnectionId,
263 receiver_id: ConnectionId,
264 request: T,
265 ) -> impl Future<Output = Result<T::Response>> {
266 self.request_internal(Some(sender_id), receiver_id, request)
267 }
268
269 pub fn request_internal<T: RequestMessage>(
270 self: &Arc<Self>,
271 original_sender_id: Option<ConnectionId>,
272 receiver_id: ConnectionId,
273 request: T,
274 ) -> impl Future<Output = Result<T::Response>> {
275 let this = self.clone();
276 let (tx, mut rx) = oneshot::channel();
277 async move {
278 let connection = this.connection(receiver_id).await?;
279 let message_id = connection
280 .next_message_id
281 .fetch_add(1, atomic::Ordering::SeqCst);
282 connection
283 .response_channels
284 .lock()
285 .await
286 .insert(message_id, tx);
287 connection
288 .writer
289 .lock()
290 .await
291 .write_message(&request.into_envelope(
292 message_id,
293 None,
294 original_sender_id.map(|id| id.0),
295 ))
296 .await?;
297 let response = rx
298 .recv()
299 .await
300 .expect("response channel was unexpectedly dropped");
301 T::Response::from_envelope(response)
302 .ok_or_else(|| anyhow!("received response of the wrong type"))
303 }
304 }
305
306 pub fn send<T: EnvelopedMessage>(
307 self: &Arc<Self>,
308 connection_id: ConnectionId,
309 message: T,
310 ) -> impl Future<Output = Result<()>> {
311 let this = self.clone();
312 async move {
313 let connection = this.connection(connection_id).await?;
314 let message_id = connection
315 .next_message_id
316 .fetch_add(1, atomic::Ordering::SeqCst);
317 connection
318 .writer
319 .lock()
320 .await
321 .write_message(&message.into_envelope(message_id, None, None))
322 .await?;
323 Ok(())
324 }
325 }
326
327 pub fn forward_send<T: EnvelopedMessage>(
328 self: &Arc<Self>,
329 sender_id: ConnectionId,
330 receiver_id: ConnectionId,
331 message: T,
332 ) -> impl Future<Output = Result<()>> {
333 let this = self.clone();
334 async move {
335 let connection = this.connection(receiver_id).await?;
336 let message_id = connection
337 .next_message_id
338 .fetch_add(1, atomic::Ordering::SeqCst);
339 connection
340 .writer
341 .lock()
342 .await
343 .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
344 .await?;
345 Ok(())
346 }
347 }
348
349 pub fn respond<T: RequestMessage>(
350 self: &Arc<Self>,
351 receipt: Receipt<T>,
352 response: T::Response,
353 ) -> impl Future<Output = Result<()>> {
354 let this = self.clone();
355 async move {
356 let connection = this.connection(receipt.sender_id).await?;
357 let message_id = connection
358 .next_message_id
359 .fetch_add(1, atomic::Ordering::SeqCst);
360 connection
361 .writer
362 .lock()
363 .await
364 .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
365 .await?;
366 Ok(())
367 }
368 }
369
370 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
371 Ok(self
372 .connections
373 .read()
374 .await
375 .get(&id)
376 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
377 .clone())
378 }
379}
380
381impl fmt::Display for ConnectionId {
382 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
383 self.0.fmt(f)
384 }
385}
386
387impl fmt::Display for PeerId {
388 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
389 self.0.fmt(f)
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients issue interleaved requests to one server, which answers
    /// each through its registered message handlers.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server
                        .respond(msg.receipt(), response1.clone())
                        .await
                        .unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2.clone());
                    server
                        .respond(msg.receipt(), response2.clone())
                        .await
                        .unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3.clone());
                    server
                        .respond(msg.receipt(), response3.clone())
                        .await
                        .unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4.clone());
                    server
                        .respond(msg.receipt(), response4.clone())
                        .await
                        .unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client must disconnect its own connection id; ids are only
            // unique per peer, so mixing them up would go unnoticed here.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting ends `handle_messages` and closes the underlying socket.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A write on a closed socket surfaces the underlying I/O error.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}