peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    pin::Pin,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct ConnectionId(u32);
 30
 31#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 32pub struct PeerId(u32);
 33
/// Per-connection state owned by a `Peer`.
struct Connection {
    // Outgoing message stream; `Peer::add_connection` builds it from a clone of the same
    // stream that backs `reader`.
    writer: Mutex<MessageStream<BoxedWriter>>,
    // Incoming message stream, read by `Peer::handle_messages` and `Peer::receive`.
    reader: Mutex<MessageStream<BoxedReader>>,
    // Senders for requests still awaiting a response, keyed by the request's message id.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    // Counter used to assign ids to every outgoing message on this connection.
    next_message_id: AtomicU32,
}
 40
 41type MessageHandler = Box<
 42    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 43>;
 44
 45#[derive(Clone, Copy)]
 46pub struct Receipt<T> {
 47    sender_id: ConnectionId,
 48    message_id: u32,
 49    payload_type: PhantomData<T>,
 50}
 51
/// A decoded incoming message of type `T` together with its routing metadata.
pub struct TypedEnvelope<T> {
    /// Connection on which the message was received.
    pub sender_id: ConnectionId,
    /// Original sender reported by the envelope. `Peer::forward_send` and
    /// `Peer::forward_request` set it; plain sends and requests leave it unset.
    pub original_sender_id: Option<PeerId>,
    /// Id the sender assigned to this message; `receipt` uses it to address a response.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
 58
 59impl<T: RequestMessage> TypedEnvelope<T> {
 60    pub fn receipt(&self) -> Receipt<T> {
 61        Receipt {
 62            sender_id: self.sender_id,
 63            message_id: self.message_id,
 64            payload_type: PhantomData,
 65        }
 66    }
 67}
 68
/// An RPC endpoint that sends requests, responses and one-way messages over a set of
/// registered connections and dispatches the messages it receives.
pub struct Peer {
    // Live connections, keyed by the id returned from `add_connection`.
    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
    // One barrier sender per running `handle_messages` loop; `disconnect` drops it to stop
    // that loop.
    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
    // Registered handlers, tried in registration order for each incoming non-response message.
    message_handlers: RwLock<Vec<MessageHandler>>,
    // Message types that already have a handler; used to reject duplicate registrations.
    handler_types: Mutex<HashSet<TypeId>>,
    // Source of fresh `ConnectionId`s.
    next_connection_id: AtomicU32,
}
 76
 77impl Peer {
 78    pub fn new() -> Arc<Self> {
 79        Arc::new(Self {
 80            connections: Default::default(),
 81            connection_close_barriers: Default::default(),
 82            message_handlers: Default::default(),
 83            handler_types: Default::default(),
 84            next_connection_id: Default::default(),
 85        })
 86    }
 87
 88    pub async fn add_message_handler<T: EnvelopedMessage>(
 89        &self,
 90    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 91        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 92            panic!("duplicate handler type");
 93        }
 94
 95        let (tx, rx) = mpsc::channel(256);
 96        self.message_handlers
 97            .write()
 98            .await
 99            .push(Box::new(move |envelope, connection_id| {
100                if envelope.as_ref().map_or(false, T::matches_envelope) {
101                    let envelope = Option::take(envelope).unwrap();
102                    let mut tx = tx.clone();
103                    Some(
104                        async move {
105                            tx.send(TypedEnvelope {
106                                sender_id: connection_id,
107                                original_sender_id: envelope.original_sender_id.map(PeerId),
108                                message_id: envelope.id,
109                                payload: T::from_envelope(envelope).unwrap(),
110                            })
111                            .await
112                            .is_err()
113                        }
114                        .boxed(),
115                    )
116                } else {
117                    None
118                }
119            }));
120        rx
121    }
122
123    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124    where
125        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126    {
127        let connection_id = ConnectionId(
128            self.next_connection_id
129                .fetch_add(1, atomic::Ordering::SeqCst),
130        );
131        self.connections.write().await.insert(
132            connection_id,
133            Arc::new(Connection {
134                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136                response_channels: Default::default(),
137                next_message_id: Default::default(),
138            }),
139        );
140        connection_id
141    }
142
143    pub async fn disconnect(&self, connection_id: ConnectionId) {
144        self.connections.write().await.remove(&connection_id);
145        self.connection_close_barriers
146            .write()
147            .await
148            .remove(&connection_id);
149    }
150
151    pub fn handle_messages(
152        self: &Arc<Self>,
153        connection_id: ConnectionId,
154    ) -> impl Future<Output = Result<()>> + 'static {
155        let (close_tx, mut close_rx) = barrier::channel();
156        let this = self.clone();
157        async move {
158            this.connection_close_barriers
159                .write()
160                .await
161                .insert(connection_id, close_tx);
162            let connection = this.connection(connection_id).await?;
163            let closed = close_rx.recv();
164            futures::pin_mut!(closed);
165
166            loop {
167                let mut reader = connection.reader.lock().await;
168                let read_message = reader.read_message();
169                futures::pin_mut!(read_message);
170
171                match futures::future::select(read_message, &mut closed).await {
172                    Either::Left((Ok(incoming), _)) => {
173                        if let Some(responding_to) = incoming.responding_to {
174                            let channel = connection
175                                .response_channels
176                                .lock()
177                                .await
178                                .remove(&responding_to);
179                            if let Some(mut tx) = channel {
180                                tx.send(incoming).await.ok();
181                            } else {
182                                log::warn!(
183                                    "received RPC response to unknown request {}",
184                                    responding_to
185                                );
186                            }
187                        } else {
188                            let mut envelope = Some(incoming);
189                            let mut handler_index = None;
190                            let mut handler_was_dropped = false;
191                            for (i, handler) in
192                                this.message_handlers.read().await.iter().enumerate()
193                            {
194                                if let Some(future) = handler(&mut envelope, connection_id) {
195                                    handler_was_dropped = future.await;
196                                    handler_index = Some(i);
197                                    break;
198                                }
199                            }
200
201                            if let Some(handler_index) = handler_index {
202                                if handler_was_dropped {
203                                    drop(this.message_handlers.write().await.remove(handler_index));
204                                }
205                            } else {
206                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
207                            }
208                        }
209                    }
210                    Either::Left((Err(error), _)) => {
211                        log::warn!("received invalid RPC message: {}", error);
212                        Err(error)?;
213                    }
214                    Either::Right(_) => return Ok(()),
215                }
216            }
217        }
218    }
219
220    pub async fn receive<M: EnvelopedMessage>(
221        self: &Arc<Self>,
222        connection_id: ConnectionId,
223    ) -> Result<TypedEnvelope<M>> {
224        let connection = self.connection(connection_id).await?;
225        let envelope = connection.reader.lock().await.read_message().await?;
226        let original_sender_id = envelope.original_sender_id;
227        let message_id = envelope.id;
228        let payload =
229            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
230        Ok(TypedEnvelope {
231            sender_id: connection_id,
232            original_sender_id: original_sender_id.map(PeerId),
233            message_id,
234            payload,
235        })
236    }
237
238    pub fn request<T: RequestMessage>(
239        self: &Arc<Self>,
240        receiver_id: ConnectionId,
241        request: T,
242    ) -> impl Future<Output = Result<T::Response>> {
243        self.request_internal(None, receiver_id, request)
244    }
245
246    pub fn forward_request<T: RequestMessage>(
247        self: &Arc<Self>,
248        sender_id: ConnectionId,
249        receiver_id: ConnectionId,
250        request: T,
251    ) -> impl Future<Output = Result<T::Response>> {
252        self.request_internal(Some(sender_id), receiver_id, request)
253    }
254
255    pub fn request_internal<T: RequestMessage>(
256        self: &Arc<Self>,
257        original_sender_id: Option<ConnectionId>,
258        receiver_id: ConnectionId,
259        request: T,
260    ) -> impl Future<Output = Result<T::Response>> {
261        let this = self.clone();
262        let (tx, mut rx) = oneshot::channel();
263        async move {
264            let connection = this.connection(receiver_id).await?;
265            let message_id = connection
266                .next_message_id
267                .fetch_add(1, atomic::Ordering::SeqCst);
268            connection
269                .response_channels
270                .lock()
271                .await
272                .insert(message_id, tx);
273            connection
274                .writer
275                .lock()
276                .await
277                .write_message(&request.into_envelope(
278                    message_id,
279                    None,
280                    original_sender_id.map(|id| id.0),
281                ))
282                .await?;
283            let response = rx
284                .recv()
285                .await
286                .expect("response channel was unexpectedly dropped");
287            T::Response::from_envelope(response)
288                .ok_or_else(|| anyhow!("received response of the wrong type"))
289        }
290    }
291
292    pub fn send<T: EnvelopedMessage>(
293        self: &Arc<Self>,
294        connection_id: ConnectionId,
295        message: T,
296    ) -> impl Future<Output = Result<()>> {
297        let this = self.clone();
298        async move {
299            let connection = this.connection(connection_id).await?;
300            let message_id = connection
301                .next_message_id
302                .fetch_add(1, atomic::Ordering::SeqCst);
303            connection
304                .writer
305                .lock()
306                .await
307                .write_message(&message.into_envelope(message_id, None, None))
308                .await?;
309            Ok(())
310        }
311    }
312
313    pub fn forward_send<T: EnvelopedMessage>(
314        self: &Arc<Self>,
315        sender_id: ConnectionId,
316        receiver_id: ConnectionId,
317        message: T,
318    ) -> impl Future<Output = Result<()>> {
319        let this = self.clone();
320        async move {
321            let connection = this.connection(receiver_id).await?;
322            let message_id = connection
323                .next_message_id
324                .fetch_add(1, atomic::Ordering::SeqCst);
325            connection
326                .writer
327                .lock()
328                .await
329                .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
330                .await?;
331            Ok(())
332        }
333    }
334
335    pub fn respond<T: RequestMessage>(
336        self: &Arc<Self>,
337        receipt: Receipt<T>,
338        response: T::Response,
339    ) -> impl Future<Output = Result<()>> {
340        let this = self.clone();
341        async move {
342            let connection = this.connection(receipt.sender_id).await?;
343            let message_id = connection
344                .next_message_id
345                .fetch_add(1, atomic::Ordering::SeqCst);
346            connection
347                .writer
348                .lock()
349                .await
350                .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
351                .await?;
352            Ok(())
353        }
354    }
355
356    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
357        Ok(self
358            .connections
359            .read()
360            .await
361            .get(&id)
362            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
363            .clone())
364    }
365}
366
367impl fmt::Display for ConnectionId {
368    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
369        self.0.fmt(f)
370    }
371}
372
373impl fmt::Display for PeerId {
374    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375        self.0.fmt(f)
376    }
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                id: 2,
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                id: 1,
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // the spawned task owns its own copies of the expectations
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // disconnect each client from its own connection
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}