// peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    pin::Pin,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct ConnectionId(u32);
 30
 31#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 32pub struct PeerId(u32);
 33
 34struct Connection {
 35    writer: Mutex<MessageStream<BoxedWriter>>,
 36    reader: Mutex<MessageStream<BoxedReader>>,
 37    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 38    next_message_id: AtomicU32,
 39}
 40
 41type MessageHandler = Box<
 42    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 43>;
 44
 45#[derive(Clone, Copy)]
 46pub struct Receipt<T> {
 47    sender_id: ConnectionId,
 48    message_id: u32,
 49    payload_type: PhantomData<T>,
 50}
 51
 52pub struct TypedEnvelope<T> {
 53    pub sender_id: ConnectionId,
 54    pub original_sender_id: Option<PeerId>,
 55    pub message_id: u32,
 56    pub payload: T,
 57}
 58
 59impl<T: RequestMessage> TypedEnvelope<T> {
 60    pub fn receipt(&self) -> Receipt<T> {
 61        Receipt {
 62            sender_id: self.sender_id,
 63            message_id: self.message_id,
 64            payload_type: PhantomData,
 65        }
 66    }
 67}
 68
 69pub struct Peer {
 70    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 71    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
 72    message_handlers: RwLock<Vec<MessageHandler>>,
 73    handler_types: Mutex<HashSet<TypeId>>,
 74    next_connection_id: AtomicU32,
 75}
 76
 77impl Peer {
 78    pub fn new() -> Arc<Self> {
 79        Arc::new(Self {
 80            connections: Default::default(),
 81            connection_close_barriers: Default::default(),
 82            message_handlers: Default::default(),
 83            handler_types: Default::default(),
 84            next_connection_id: Default::default(),
 85        })
 86    }
 87
 88    pub async fn add_message_handler<T: EnvelopedMessage>(
 89        &self,
 90    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 91        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 92            panic!("duplicate handler type");
 93        }
 94
 95        let (tx, rx) = mpsc::channel(256);
 96        self.message_handlers
 97            .write()
 98            .await
 99            .push(Box::new(move |envelope, connection_id| {
100                if envelope.as_ref().map_or(false, T::matches_envelope) {
101                    let envelope = Option::take(envelope).unwrap();
102                    let mut tx = tx.clone();
103                    Some(
104                        async move {
105                            tx.send(TypedEnvelope {
106                                sender_id: connection_id,
107                                original_sender_id: envelope.original_sender_id.map(PeerId),
108                                message_id: envelope.id,
109                                payload: T::from_envelope(envelope).unwrap(),
110                            })
111                            .await
112                            .is_err()
113                        }
114                        .boxed(),
115                    )
116                } else {
117                    None
118                }
119            }));
120        rx
121    }
122
123    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124    where
125        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126    {
127        let connection_id = ConnectionId(
128            self.next_connection_id
129                .fetch_add(1, atomic::Ordering::SeqCst),
130        );
131        self.connections.write().await.insert(
132            connection_id,
133            Arc::new(Connection {
134                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136                response_channels: Default::default(),
137                next_message_id: Default::default(),
138            }),
139        );
140        connection_id
141    }
142
143    pub async fn disconnect(&self, connection_id: ConnectionId) {
144        self.connections.write().await.remove(&connection_id);
145        self.connection_close_barriers
146            .write()
147            .await
148            .remove(&connection_id);
149    }
150
151    pub async fn reset(&self) {
152        self.connections.write().await.clear();
153        self.connection_close_barriers.write().await.clear();
154        self.handler_types.lock().await.clear();
155        self.message_handlers.write().await.clear();
156    }
157
158    pub fn handle_messages(
159        self: &Arc<Self>,
160        connection_id: ConnectionId,
161    ) -> impl Future<Output = Result<()>> + 'static {
162        let (close_tx, mut close_rx) = barrier::channel();
163        let this = self.clone();
164        async move {
165            this.connection_close_barriers
166                .write()
167                .await
168                .insert(connection_id, close_tx);
169            let connection = this.connection(connection_id).await?;
170            let closed = close_rx.recv();
171            futures::pin_mut!(closed);
172
173            loop {
174                let mut reader = connection.reader.lock().await;
175                let read_message = reader.read_message();
176                futures::pin_mut!(read_message);
177
178                match futures::future::select(read_message, &mut closed).await {
179                    Either::Left((Ok(incoming), _)) => {
180                        if let Some(responding_to) = incoming.responding_to {
181                            let channel = connection
182                                .response_channels
183                                .lock()
184                                .await
185                                .remove(&responding_to);
186                            if let Some(mut tx) = channel {
187                                tx.send(incoming).await.ok();
188                            } else {
189                                log::warn!(
190                                    "received RPC response to unknown request {}",
191                                    responding_to
192                                );
193                            }
194                        } else {
195                            let mut envelope = Some(incoming);
196                            let mut handler_index = None;
197                            let mut handler_was_dropped = false;
198                            for (i, handler) in
199                                this.message_handlers.read().await.iter().enumerate()
200                            {
201                                if let Some(future) = handler(&mut envelope, connection_id) {
202                                    handler_was_dropped = future.await;
203                                    handler_index = Some(i);
204                                    break;
205                                }
206                            }
207
208                            if let Some(handler_index) = handler_index {
209                                if handler_was_dropped {
210                                    drop(this.message_handlers.write().await.remove(handler_index));
211                                }
212                            } else {
213                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
214                            }
215                        }
216                    }
217                    Either::Left((Err(error), _)) => {
218                        log::warn!("received invalid RPC message: {}", error);
219                        Err(error)?;
220                    }
221                    Either::Right(_) => return Ok(()),
222                }
223            }
224        }
225    }
226
227    pub async fn receive<M: EnvelopedMessage>(
228        self: &Arc<Self>,
229        connection_id: ConnectionId,
230    ) -> Result<TypedEnvelope<M>> {
231        let connection = self.connection(connection_id).await?;
232        let envelope = connection.reader.lock().await.read_message().await?;
233        let original_sender_id = envelope.original_sender_id;
234        let message_id = envelope.id;
235        let payload =
236            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
237        Ok(TypedEnvelope {
238            sender_id: connection_id,
239            original_sender_id: original_sender_id.map(PeerId),
240            message_id,
241            payload,
242        })
243    }
244
245    pub fn request<T: RequestMessage>(
246        self: &Arc<Self>,
247        receiver_id: ConnectionId,
248        request: T,
249    ) -> impl Future<Output = Result<T::Response>> {
250        self.request_internal(None, receiver_id, request)
251    }
252
253    pub fn forward_request<T: RequestMessage>(
254        self: &Arc<Self>,
255        sender_id: ConnectionId,
256        receiver_id: ConnectionId,
257        request: T,
258    ) -> impl Future<Output = Result<T::Response>> {
259        self.request_internal(Some(sender_id), receiver_id, request)
260    }
261
262    pub fn request_internal<T: RequestMessage>(
263        self: &Arc<Self>,
264        original_sender_id: Option<ConnectionId>,
265        receiver_id: ConnectionId,
266        request: T,
267    ) -> impl Future<Output = Result<T::Response>> {
268        let this = self.clone();
269        let (tx, mut rx) = oneshot::channel();
270        async move {
271            let connection = this.connection(receiver_id).await?;
272            let message_id = connection
273                .next_message_id
274                .fetch_add(1, atomic::Ordering::SeqCst);
275            connection
276                .response_channels
277                .lock()
278                .await
279                .insert(message_id, tx);
280            connection
281                .writer
282                .lock()
283                .await
284                .write_message(&request.into_envelope(
285                    message_id,
286                    None,
287                    original_sender_id.map(|id| id.0),
288                ))
289                .await?;
290            let response = rx
291                .recv()
292                .await
293                .expect("response channel was unexpectedly dropped");
294            T::Response::from_envelope(response)
295                .ok_or_else(|| anyhow!("received response of the wrong type"))
296        }
297    }
298
299    pub fn send<T: EnvelopedMessage>(
300        self: &Arc<Self>,
301        connection_id: ConnectionId,
302        message: T,
303    ) -> impl Future<Output = Result<()>> {
304        let this = self.clone();
305        async move {
306            let connection = this.connection(connection_id).await?;
307            let message_id = connection
308                .next_message_id
309                .fetch_add(1, atomic::Ordering::SeqCst);
310            connection
311                .writer
312                .lock()
313                .await
314                .write_message(&message.into_envelope(message_id, None, None))
315                .await?;
316            Ok(())
317        }
318    }
319
320    pub fn forward_send<T: EnvelopedMessage>(
321        self: &Arc<Self>,
322        sender_id: ConnectionId,
323        receiver_id: ConnectionId,
324        message: T,
325    ) -> impl Future<Output = Result<()>> {
326        let this = self.clone();
327        async move {
328            let connection = this.connection(receiver_id).await?;
329            let message_id = connection
330                .next_message_id
331                .fetch_add(1, atomic::Ordering::SeqCst);
332            connection
333                .writer
334                .lock()
335                .await
336                .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
337                .await?;
338            Ok(())
339        }
340    }
341
342    pub fn respond<T: RequestMessage>(
343        self: &Arc<Self>,
344        receipt: Receipt<T>,
345        response: T::Response,
346    ) -> impl Future<Output = Result<()>> {
347        let this = self.clone();
348        async move {
349            let connection = this.connection(receipt.sender_id).await?;
350            let message_id = connection
351                .next_message_id
352                .fetch_add(1, atomic::Ordering::SeqCst);
353            connection
354                .writer
355                .lock()
356                .await
357                .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
358                .await?;
359            Ok(())
360        }
361    }
362
363    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
364        Ok(self
365            .connections
366            .read()
367            .await
368            .get(&id)
369            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
370            .clone())
371    }
372}
373
374impl fmt::Display for ConnectionId {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        self.0.fmt(f)
377    }
378}
379
380impl fmt::Display for PeerId {
381    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382        self.0.fmt(f)
383    }
384}
385
386#[cfg(test)]
387mod tests {
388    use super::*;
389    use smol::{
390        io::AsyncWriteExt,
391        net::unix::{UnixListener, UnixStream},
392    };
393    use std::io;
394    use tempdir::TempDir;
395
396    #[test]
397    fn test_request_response() {
398        smol::block_on(async move {
399            // create socket
400            let socket_dir_path = TempDir::new("test-request-response").unwrap();
401            let socket_path = socket_dir_path.path().join("test.sock");
402            let listener = UnixListener::bind(&socket_path).unwrap();
403
404            // create 2 clients connected to 1 server
405            let server = Peer::new();
406            let client1 = Peer::new();
407            let client2 = Peer::new();
408            let client1_conn_id = client1
409                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
410                .await;
411            let client2_conn_id = client2
412                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
413                .await;
414            let server_conn_id1 = server
415                .add_connection(listener.accept().await.unwrap().0)
416                .await;
417            let server_conn_id2 = server
418                .add_connection(listener.accept().await.unwrap().0)
419                .await;
420            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
421            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
422            smol::spawn(server.handle_messages(server_conn_id1)).detach();
423            smol::spawn(server.handle_messages(server_conn_id2)).detach();
424
425            // define the expected requests and responses
426            let request1 = proto::Auth {
427                user_id: 1,
428                access_token: "token-1".to_string(),
429            };
430            let response1 = proto::AuthResponse {
431                credentials_valid: true,
432            };
433            let request2 = proto::Auth {
434                user_id: 2,
435                access_token: "token-2".to_string(),
436            };
437            let response2 = proto::AuthResponse {
438                credentials_valid: false,
439            };
440            let request3 = proto::OpenBuffer {
441                worktree_id: 1,
442                path: "path/two".to_string(),
443            };
444            let response3 = proto::OpenBufferResponse {
445                buffer: Some(proto::Buffer {
446                    id: 2,
447                    content: "path/two content".to_string(),
448                    history: vec![],
449                }),
450            };
451            let request4 = proto::OpenBuffer {
452                worktree_id: 2,
453                path: "path/one".to_string(),
454            };
455            let response4 = proto::OpenBufferResponse {
456                buffer: Some(proto::Buffer {
457                    id: 1,
458                    content: "path/one content".to_string(),
459                    history: vec![],
460                }),
461            };
462
463            // on the server, respond to two requests for each client
464            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
465            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
466            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
467            smol::spawn({
468                let request1 = request1.clone();
469                let request2 = request2.clone();
470                let request3 = request3.clone();
471                let request4 = request4.clone();
472                let response1 = response1.clone();
473                let response2 = response2.clone();
474                let response3 = response3.clone();
475                let response4 = response4.clone();
476                async move {
477                    let msg = auth_rx.recv().await.unwrap();
478                    assert_eq!(msg.payload, request1);
479                    server
480                        .respond(msg.receipt(), response1.clone())
481                        .await
482                        .unwrap();
483
484                    let msg = auth_rx.recv().await.unwrap();
485                    assert_eq!(msg.payload, request2.clone());
486                    server
487                        .respond(msg.receipt(), response2.clone())
488                        .await
489                        .unwrap();
490
491                    let msg = open_buffer_rx.recv().await.unwrap();
492                    assert_eq!(msg.payload, request3.clone());
493                    server
494                        .respond(msg.receipt(), response3.clone())
495                        .await
496                        .unwrap();
497
498                    let msg = open_buffer_rx.recv().await.unwrap();
499                    assert_eq!(msg.payload, request4.clone());
500                    server
501                        .respond(msg.receipt(), response4.clone())
502                        .await
503                        .unwrap();
504
505                    server_done_tx.send(()).await.unwrap();
506                }
507            })
508            .detach();
509
510            assert_eq!(
511                client1.request(client1_conn_id, request1).await.unwrap(),
512                response1
513            );
514            assert_eq!(
515                client2.request(client2_conn_id, request2).await.unwrap(),
516                response2
517            );
518            assert_eq!(
519                client2.request(client2_conn_id, request3).await.unwrap(),
520                response3
521            );
522            assert_eq!(
523                client1.request(client1_conn_id, request4).await.unwrap(),
524                response4
525            );
526
527            client1.disconnect(client1_conn_id).await;
528            client2.disconnect(client1_conn_id).await;
529
530            server_done_rx.recv().await.unwrap();
531        });
532    }
533
534    #[test]
535    fn test_disconnect() {
536        smol::block_on(async move {
537            let socket_dir_path = TempDir::new("drop-client").unwrap();
538            let socket_path = socket_dir_path.path().join(".sock");
539            let listener = UnixListener::bind(&socket_path).unwrap();
540            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
541            let (mut server_conn, _) = listener.accept().await.unwrap();
542
543            let client = Peer::new();
544            let connection_id = client.add_connection(client_conn).await;
545            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
546                barrier::channel();
547            let handle_messages = client.handle_messages(connection_id);
548            smol::spawn(async move {
549                handle_messages.await.ok();
550                incoming_messages_ended_tx.send(()).await.unwrap();
551            })
552            .detach();
553            client.disconnect(connection_id).await;
554
555            incoming_messages_ended_rx.recv().await;
556
557            let err = server_conn.write(&[]).await.unwrap_err();
558            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
559        });
560    }
561
562    #[test]
563    fn test_io_error() {
564        smol::block_on(async move {
565            let socket_dir_path = TempDir::new("io-error").unwrap();
566            let socket_path = socket_dir_path.path().join(".sock");
567            let _listener = UnixListener::bind(&socket_path).unwrap();
568            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
569            client_conn.close().await.unwrap();
570
571            let client = Peer::new();
572            let connection_id = client.add_connection(client_conn).await;
573            smol::spawn(client.handle_messages(connection_id)).detach();
574
575            let err = client
576                .request(
577                    connection_id,
578                    proto::Auth {
579                        user_id: 42,
580                        access_token: "token".to_string(),
581                    },
582                )
583                .await
584                .unwrap_err();
585            assert_eq!(
586                err.downcast_ref::<io::Error>().unwrap().kind(),
587                io::ErrorKind::BrokenPipe
588            );
589        });
590    }
591}