peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    pin::Pin,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct ConnectionId(pub u32);
 30
 31#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 32pub struct PeerId(pub u32);
 33
/// Per-connection state shared between a connection's `handle_messages` loop
/// and the `Peer` methods that write to it.
struct Connection {
    /// Outgoing message stream; the mutex serializes concurrent writers.
    writer: Mutex<MessageStream<BoxedWriter>>,
    /// Incoming message stream; locked by whoever is reading from the connection.
    reader: Mutex<MessageStream<BoxedReader>>,
    /// Requests awaiting a reply, keyed by the request's message id. An incoming
    /// envelope whose `responding_to` matches a key is delivered to that sender.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Source of ids for messages sent on this connection.
    next_message_id: AtomicU32,
}
 40
/// A type-erased handler for one message type.
///
/// When the envelope is of the handler's type, the handler takes it out of the
/// `Option` and returns a future that delivers it; that future resolves to
/// `true` if the handler's receiver has been dropped. Otherwise the handler
/// returns `None` and leaves the envelope in place for other handlers.
type MessageHandler = Box<
    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
>;
 44
 45#[derive(Clone, Copy)]
 46pub struct Receipt<T> {
 47    sender_id: ConnectionId,
 48    message_id: u32,
 49    payload_type: PhantomData<T>,
 50}
 51
/// A decoded incoming message of type `T` together with its routing metadata.
pub struct TypedEnvelope<T> {
    /// The connection the message arrived on.
    pub sender_id: ConnectionId,
    /// Copied from the envelope's `original_sender_id`, which `Peer::forward_send`
    /// and `Peer::forward_request` populate; read it via `original_sender_id()`.
    original_sender_id: Option<PeerId>,
    /// Id the sender assigned to this message; a response refers back to it.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
 58
 59impl<T> TypedEnvelope<T> {
 60    pub fn original_sender_id(&self) -> Result<PeerId> {
 61        self.original_sender_id
 62            .ok_or_else(|| anyhow!("missing original_sender_id"))
 63    }
 64}
 65
 66impl<T: RequestMessage> TypedEnvelope<T> {
 67    pub fn receipt(&self) -> Receipt<T> {
 68        Receipt {
 69            sender_id: self.sender_id,
 70            message_id: self.message_id,
 71            payload_type: PhantomData,
 72        }
 73    }
 74}
 75
 76pub struct Peer {
 77    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 78    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
 79    message_handlers: RwLock<Vec<MessageHandler>>,
 80    handler_types: Mutex<HashSet<TypeId>>,
 81    next_connection_id: AtomicU32,
 82}
 83
 84impl Peer {
 85    pub fn new() -> Arc<Self> {
 86        Arc::new(Self {
 87            connections: Default::default(),
 88            connection_close_barriers: Default::default(),
 89            message_handlers: Default::default(),
 90            handler_types: Default::default(),
 91            next_connection_id: Default::default(),
 92        })
 93    }
 94
 95    pub async fn add_message_handler<T: EnvelopedMessage>(
 96        &self,
 97    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 98        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 99            panic!("duplicate handler type");
100        }
101
102        let (tx, rx) = mpsc::channel(256);
103        self.message_handlers
104            .write()
105            .await
106            .push(Box::new(move |envelope, connection_id| {
107                if envelope.as_ref().map_or(false, T::matches_envelope) {
108                    let envelope = Option::take(envelope).unwrap();
109                    let mut tx = tx.clone();
110                    Some(
111                        async move {
112                            tx.send(TypedEnvelope {
113                                sender_id: connection_id,
114                                original_sender_id: envelope.original_sender_id.map(PeerId),
115                                message_id: envelope.id,
116                                payload: T::from_envelope(envelope).unwrap(),
117                            })
118                            .await
119                            .is_err()
120                        }
121                        .boxed(),
122                    )
123                } else {
124                    None
125                }
126            }));
127        rx
128    }
129
130    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
131    where
132        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
133    {
134        let connection_id = ConnectionId(
135            self.next_connection_id
136                .fetch_add(1, atomic::Ordering::SeqCst),
137        );
138        self.connections.write().await.insert(
139            connection_id,
140            Arc::new(Connection {
141                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
142                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
143                response_channels: Default::default(),
144                next_message_id: Default::default(),
145            }),
146        );
147        connection_id
148    }
149
150    pub async fn disconnect(&self, connection_id: ConnectionId) {
151        self.connections.write().await.remove(&connection_id);
152        self.connection_close_barriers
153            .write()
154            .await
155            .remove(&connection_id);
156    }
157
158    pub async fn reset(&self) {
159        self.connections.write().await.clear();
160        self.connection_close_barriers.write().await.clear();
161        self.handler_types.lock().await.clear();
162        self.message_handlers.write().await.clear();
163    }
164
165    pub fn handle_messages(
166        self: &Arc<Self>,
167        connection_id: ConnectionId,
168    ) -> impl Future<Output = Result<()>> + 'static {
169        let (close_tx, mut close_rx) = barrier::channel();
170        let this = self.clone();
171        async move {
172            this.connection_close_barriers
173                .write()
174                .await
175                .insert(connection_id, close_tx);
176            let connection = this.connection(connection_id).await?;
177            let closed = close_rx.recv();
178            futures::pin_mut!(closed);
179
180            loop {
181                let mut reader = connection.reader.lock().await;
182                let read_message = reader.read_message();
183                futures::pin_mut!(read_message);
184
185                match futures::future::select(read_message, &mut closed).await {
186                    Either::Left((Ok(incoming), _)) => {
187                        if let Some(responding_to) = incoming.responding_to {
188                            let channel = connection
189                                .response_channels
190                                .lock()
191                                .await
192                                .remove(&responding_to);
193                            if let Some(mut tx) = channel {
194                                tx.send(incoming).await.ok();
195                            } else {
196                                log::warn!(
197                                    "received RPC response to unknown request {}",
198                                    responding_to
199                                );
200                            }
201                        } else {
202                            let mut envelope = Some(incoming);
203                            let mut handler_index = None;
204                            let mut handler_was_dropped = false;
205                            for (i, handler) in
206                                this.message_handlers.read().await.iter().enumerate()
207                            {
208                                if let Some(future) = handler(&mut envelope, connection_id) {
209                                    handler_was_dropped = future.await;
210                                    handler_index = Some(i);
211                                    break;
212                                }
213                            }
214
215                            if let Some(handler_index) = handler_index {
216                                if handler_was_dropped {
217                                    drop(this.message_handlers.write().await.remove(handler_index));
218                                }
219                            } else {
220                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
221                            }
222                        }
223                    }
224                    Either::Left((Err(error), _)) => {
225                        log::warn!("received invalid RPC message: {}", error);
226                        Err(error)?;
227                    }
228                    Either::Right(_) => return Ok(()),
229                }
230            }
231        }
232    }
233
234    pub async fn receive<M: EnvelopedMessage>(
235        self: &Arc<Self>,
236        connection_id: ConnectionId,
237    ) -> Result<TypedEnvelope<M>> {
238        let connection = self.connection(connection_id).await?;
239        let envelope = connection.reader.lock().await.read_message().await?;
240        let original_sender_id = envelope.original_sender_id;
241        let message_id = envelope.id;
242        let payload =
243            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
244        Ok(TypedEnvelope {
245            sender_id: connection_id,
246            original_sender_id: original_sender_id.map(PeerId),
247            message_id,
248            payload,
249        })
250    }
251
252    pub fn request<T: RequestMessage>(
253        self: &Arc<Self>,
254        receiver_id: ConnectionId,
255        request: T,
256    ) -> impl Future<Output = Result<T::Response>> {
257        self.request_internal(None, receiver_id, request)
258    }
259
260    pub fn forward_request<T: RequestMessage>(
261        self: &Arc<Self>,
262        sender_id: ConnectionId,
263        receiver_id: ConnectionId,
264        request: T,
265    ) -> impl Future<Output = Result<T::Response>> {
266        self.request_internal(Some(sender_id), receiver_id, request)
267    }
268
269    pub fn request_internal<T: RequestMessage>(
270        self: &Arc<Self>,
271        original_sender_id: Option<ConnectionId>,
272        receiver_id: ConnectionId,
273        request: T,
274    ) -> impl Future<Output = Result<T::Response>> {
275        let this = self.clone();
276        let (tx, mut rx) = oneshot::channel();
277        async move {
278            let connection = this.connection(receiver_id).await?;
279            let message_id = connection
280                .next_message_id
281                .fetch_add(1, atomic::Ordering::SeqCst);
282            connection
283                .response_channels
284                .lock()
285                .await
286                .insert(message_id, tx);
287            connection
288                .writer
289                .lock()
290                .await
291                .write_message(&request.into_envelope(
292                    message_id,
293                    None,
294                    original_sender_id.map(|id| id.0),
295                ))
296                .await?;
297            let response = rx
298                .recv()
299                .await
300                .expect("response channel was unexpectedly dropped");
301            T::Response::from_envelope(response)
302                .ok_or_else(|| anyhow!("received response of the wrong type"))
303        }
304    }
305
306    pub fn send<T: EnvelopedMessage>(
307        self: &Arc<Self>,
308        connection_id: ConnectionId,
309        message: T,
310    ) -> impl Future<Output = Result<()>> {
311        let this = self.clone();
312        async move {
313            let connection = this.connection(connection_id).await?;
314            let message_id = connection
315                .next_message_id
316                .fetch_add(1, atomic::Ordering::SeqCst);
317            connection
318                .writer
319                .lock()
320                .await
321                .write_message(&message.into_envelope(message_id, None, None))
322                .await?;
323            Ok(())
324        }
325    }
326
327    pub fn forward_send<T: EnvelopedMessage>(
328        self: &Arc<Self>,
329        sender_id: ConnectionId,
330        receiver_id: ConnectionId,
331        message: T,
332    ) -> impl Future<Output = Result<()>> {
333        let this = self.clone();
334        async move {
335            let connection = this.connection(receiver_id).await?;
336            let message_id = connection
337                .next_message_id
338                .fetch_add(1, atomic::Ordering::SeqCst);
339            connection
340                .writer
341                .lock()
342                .await
343                .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
344                .await?;
345            Ok(())
346        }
347    }
348
349    pub fn respond<T: RequestMessage>(
350        self: &Arc<Self>,
351        receipt: Receipt<T>,
352        response: T::Response,
353    ) -> impl Future<Output = Result<()>> {
354        let this = self.clone();
355        async move {
356            let connection = this.connection(receipt.sender_id).await?;
357            let message_id = connection
358                .next_message_id
359                .fetch_add(1, atomic::Ordering::SeqCst);
360            connection
361                .writer
362                .lock()
363                .await
364                .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
365                .await?;
366            Ok(())
367        }
368    }
369
370    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
371        Ok(self
372            .connections
373            .read()
374            .await
375            .get(&id)
376            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
377            .clone())
378    }
379}
380
381impl fmt::Display for ConnectionId {
382    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
383        self.0.fmt(f)
384    }
385}
386
387impl fmt::Display for PeerId {
388    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
389        self.0.fmt(f)
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}