1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 pin::Pin,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
27
/// Identifies one connection registered on a [`Peer`].
///
/// Ids are handed out sequentially by each peer, so they are only meaningful
/// relative to the peer that issued them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u32);
30
/// Identifies the peer on whose behalf a forwarded message was sent.
///
/// Built from an envelope's `original_sender_id` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(u32);
33
34struct Connection {
35 writer: Mutex<MessageStream<BoxedWriter>>,
36 reader: Mutex<MessageStream<BoxedReader>>,
37 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
38 next_message_id: AtomicU32,
39}
40
41type MessageHandler = Box<
42 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
43>;
44
45#[derive(Clone, Copy)]
46pub struct Receipt<T> {
47 sender_id: ConnectionId,
48 message_id: u32,
49 payload_type: PhantomData<T>,
50}
51
52pub struct TypedEnvelope<T> {
53 pub sender_id: ConnectionId,
54 pub original_sender_id: Option<PeerId>,
55 pub message_id: u32,
56 pub payload: T,
57}
58
59impl<T: RequestMessage> TypedEnvelope<T> {
60 pub fn receipt(&self) -> Receipt<T> {
61 Receipt {
62 sender_id: self.sender_id,
63 message_id: self.message_id,
64 payload_type: PhantomData,
65 }
66 }
67}
68
69pub struct Peer {
70 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
71 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
72 message_handlers: RwLock<Vec<MessageHandler>>,
73 handler_types: Mutex<HashSet<TypeId>>,
74 next_connection_id: AtomicU32,
75}
76
77impl Peer {
78 pub fn new() -> Arc<Self> {
79 Arc::new(Self {
80 connections: Default::default(),
81 connection_close_barriers: Default::default(),
82 message_handlers: Default::default(),
83 handler_types: Default::default(),
84 next_connection_id: Default::default(),
85 })
86 }
87
88 pub async fn add_message_handler<T: EnvelopedMessage>(
89 &self,
90 ) -> mpsc::Receiver<TypedEnvelope<T>> {
91 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
92 panic!("duplicate handler type");
93 }
94
95 let (tx, rx) = mpsc::channel(256);
96 self.message_handlers
97 .write()
98 .await
99 .push(Box::new(move |envelope, connection_id| {
100 if envelope.as_ref().map_or(false, T::matches_envelope) {
101 let envelope = Option::take(envelope).unwrap();
102 let mut tx = tx.clone();
103 Some(
104 async move {
105 tx.send(TypedEnvelope {
106 sender_id: connection_id,
107 original_sender_id: envelope.original_sender_id.map(PeerId),
108 message_id: envelope.id,
109 payload: T::from_envelope(envelope).unwrap(),
110 })
111 .await
112 .is_err()
113 }
114 .boxed(),
115 )
116 } else {
117 None
118 }
119 }));
120 rx
121 }
122
123 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124 where
125 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126 {
127 let connection_id = ConnectionId(
128 self.next_connection_id
129 .fetch_add(1, atomic::Ordering::SeqCst),
130 );
131 self.connections.write().await.insert(
132 connection_id,
133 Arc::new(Connection {
134 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136 response_channels: Default::default(),
137 next_message_id: Default::default(),
138 }),
139 );
140 connection_id
141 }
142
143 pub async fn disconnect(&self, connection_id: ConnectionId) {
144 self.connections.write().await.remove(&connection_id);
145 self.connection_close_barriers
146 .write()
147 .await
148 .remove(&connection_id);
149 }
150
151 pub fn handle_messages(
152 self: &Arc<Self>,
153 connection_id: ConnectionId,
154 ) -> impl Future<Output = Result<()>> + 'static {
155 let (close_tx, mut close_rx) = barrier::channel();
156 let this = self.clone();
157 async move {
158 this.connection_close_barriers
159 .write()
160 .await
161 .insert(connection_id, close_tx);
162 let connection = this.connection(connection_id).await?;
163 let closed = close_rx.recv();
164 futures::pin_mut!(closed);
165
166 loop {
167 let mut reader = connection.reader.lock().await;
168 let read_message = reader.read_message();
169 futures::pin_mut!(read_message);
170
171 match futures::future::select(read_message, &mut closed).await {
172 Either::Left((Ok(incoming), _)) => {
173 if let Some(responding_to) = incoming.responding_to {
174 let channel = connection
175 .response_channels
176 .lock()
177 .await
178 .remove(&responding_to);
179 if let Some(mut tx) = channel {
180 tx.send(incoming).await.ok();
181 } else {
182 log::warn!(
183 "received RPC response to unknown request {}",
184 responding_to
185 );
186 }
187 } else {
188 let mut envelope = Some(incoming);
189 let mut handler_index = None;
190 let mut handler_was_dropped = false;
191 for (i, handler) in
192 this.message_handlers.read().await.iter().enumerate()
193 {
194 if let Some(future) = handler(&mut envelope, connection_id) {
195 handler_was_dropped = future.await;
196 handler_index = Some(i);
197 break;
198 }
199 }
200
201 if let Some(handler_index) = handler_index {
202 if handler_was_dropped {
203 drop(this.message_handlers.write().await.remove(handler_index));
204 }
205 } else {
206 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
207 }
208 }
209 }
210 Either::Left((Err(error), _)) => {
211 log::warn!("received invalid RPC message: {}", error);
212 Err(error)?;
213 }
214 Either::Right(_) => return Ok(()),
215 }
216 }
217 }
218 }
219
220 pub async fn receive<M: EnvelopedMessage>(
221 self: &Arc<Self>,
222 connection_id: ConnectionId,
223 ) -> Result<TypedEnvelope<M>> {
224 let connection = self.connection(connection_id).await?;
225 let envelope = connection.reader.lock().await.read_message().await?;
226 let original_sender_id = envelope.original_sender_id;
227 let message_id = envelope.id;
228 let payload =
229 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
230 Ok(TypedEnvelope {
231 sender_id: connection_id,
232 original_sender_id: original_sender_id.map(PeerId),
233 message_id,
234 payload,
235 })
236 }
237
238 pub fn request<T: RequestMessage>(
239 self: &Arc<Self>,
240 receiver_id: ConnectionId,
241 request: T,
242 ) -> impl Future<Output = Result<T::Response>> {
243 self.request_internal(None, receiver_id, request)
244 }
245
246 pub fn forward_request<T: RequestMessage>(
247 self: &Arc<Self>,
248 sender_id: ConnectionId,
249 receiver_id: ConnectionId,
250 request: T,
251 ) -> impl Future<Output = Result<T::Response>> {
252 self.request_internal(Some(sender_id), receiver_id, request)
253 }
254
255 pub fn request_internal<T: RequestMessage>(
256 self: &Arc<Self>,
257 original_sender_id: Option<ConnectionId>,
258 receiver_id: ConnectionId,
259 request: T,
260 ) -> impl Future<Output = Result<T::Response>> {
261 let this = self.clone();
262 let (tx, mut rx) = oneshot::channel();
263 async move {
264 let connection = this.connection(receiver_id).await?;
265 let message_id = connection
266 .next_message_id
267 .fetch_add(1, atomic::Ordering::SeqCst);
268 connection
269 .response_channels
270 .lock()
271 .await
272 .insert(message_id, tx);
273 connection
274 .writer
275 .lock()
276 .await
277 .write_message(&request.into_envelope(
278 message_id,
279 None,
280 original_sender_id.map(|id| id.0),
281 ))
282 .await?;
283 let response = rx
284 .recv()
285 .await
286 .expect("response channel was unexpectedly dropped");
287 T::Response::from_envelope(response)
288 .ok_or_else(|| anyhow!("received response of the wrong type"))
289 }
290 }
291
292 pub fn send<T: EnvelopedMessage>(
293 self: &Arc<Self>,
294 connection_id: ConnectionId,
295 message: T,
296 ) -> impl Future<Output = Result<()>> {
297 let this = self.clone();
298 async move {
299 let connection = this.connection(connection_id).await?;
300 let message_id = connection
301 .next_message_id
302 .fetch_add(1, atomic::Ordering::SeqCst);
303 connection
304 .writer
305 .lock()
306 .await
307 .write_message(&message.into_envelope(message_id, None, None))
308 .await?;
309 Ok(())
310 }
311 }
312
313 pub fn forward_send<T: EnvelopedMessage>(
314 self: &Arc<Self>,
315 sender_id: ConnectionId,
316 receiver_id: ConnectionId,
317 message: T,
318 ) -> impl Future<Output = Result<()>> {
319 let this = self.clone();
320 async move {
321 let connection = this.connection(receiver_id).await?;
322 let message_id = connection
323 .next_message_id
324 .fetch_add(1, atomic::Ordering::SeqCst);
325 connection
326 .writer
327 .lock()
328 .await
329 .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
330 .await?;
331 Ok(())
332 }
333 }
334
335 pub fn respond<T: RequestMessage>(
336 self: &Arc<Self>,
337 receipt: Receipt<T>,
338 response: T::Response,
339 ) -> impl Future<Output = Result<()>> {
340 let this = self.clone();
341 async move {
342 let connection = this.connection(receipt.sender_id).await?;
343 let message_id = connection
344 .next_message_id
345 .fetch_add(1, atomic::Ordering::SeqCst);
346 connection
347 .writer
348 .lock()
349 .await
350 .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
351 .await?;
352 Ok(())
353 }
354 }
355
356 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
357 Ok(self
358 .connections
359 .read()
360 .await
361 .get(&id)
362 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
363 .clone())
364 }
365}
366
367impl fmt::Display for ConnectionId {
368 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369 self.0.fmt(f)
370 }
371}
372
373impl fmt::Display for PeerId {
374 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375 self.0.fmt(f)
376 }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                id: 2,
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                id: 1,
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The task gets its own copies; the originals are used by the clients below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client must disconnect its own connection id. Both ids happen to be
            // 0, since every peer numbers its connections independently, so using the
            // wrong one here would go unnoticed.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}