peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    marker::PhantomData,
 18    pin::Pin,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct ConnectionId(u32);
 30
 31#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 32pub struct PeerId(u32);
 33
 34struct Connection {
 35    writer: Mutex<MessageStream<BoxedWriter>>,
 36    reader: Mutex<MessageStream<BoxedReader>>,
 37    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 38    next_message_id: AtomicU32,
 39}
 40
 41type MessageHandler = Box<
 42    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 43>;
 44
 45#[derive(Clone, Copy)]
 46pub struct Receipt<T> {
 47    sender_id: ConnectionId,
 48    message_id: u32,
 49    payload_type: PhantomData<T>,
 50}
 51
 52pub struct TypedEnvelope<T> {
 53    pub sender_id: ConnectionId,
 54    pub original_sender_id: Option<PeerId>,
 55    pub message_id: u32,
 56    pub payload: T,
 57}
 58
 59impl<T: RequestMessage> TypedEnvelope<T> {
 60    pub fn receipt(&self) -> Receipt<T> {
 61        Receipt {
 62            sender_id: self.sender_id,
 63            message_id: self.message_id,
 64            payload_type: PhantomData,
 65        }
 66    }
 67}
 68
 69pub struct Peer {
 70    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 71    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
 72    message_handlers: RwLock<Vec<MessageHandler>>,
 73    handler_types: Mutex<HashSet<TypeId>>,
 74    next_connection_id: AtomicU32,
 75}
 76
 77impl Peer {
 78    pub fn new() -> Arc<Self> {
 79        Arc::new(Self {
 80            connections: Default::default(),
 81            connection_close_barriers: Default::default(),
 82            message_handlers: Default::default(),
 83            handler_types: Default::default(),
 84            next_connection_id: Default::default(),
 85        })
 86    }
 87
 88    pub async fn add_message_handler<T: EnvelopedMessage>(
 89        &self,
 90    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 91        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 92            panic!("duplicate handler type");
 93        }
 94
 95        let (tx, rx) = mpsc::channel(256);
 96        self.message_handlers
 97            .write()
 98            .await
 99            .push(Box::new(move |envelope, connection_id| {
100                if envelope.as_ref().map_or(false, T::matches_envelope) {
101                    let envelope = Option::take(envelope).unwrap();
102                    let mut tx = tx.clone();
103                    Some(
104                        async move {
105                            tx.send(TypedEnvelope {
106                                sender_id: connection_id,
107                                original_sender_id: envelope.original_sender_id.map(PeerId),
108                                message_id: envelope.id,
109                                payload: T::from_envelope(envelope).unwrap(),
110                            })
111                            .await
112                            .is_err()
113                        }
114                        .boxed(),
115                    )
116                } else {
117                    None
118                }
119            }));
120        rx
121    }
122
123    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124    where
125        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126    {
127        let connection_id = ConnectionId(
128            self.next_connection_id
129                .fetch_add(1, atomic::Ordering::SeqCst),
130        );
131        self.connections.write().await.insert(
132            connection_id,
133            Arc::new(Connection {
134                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136                response_channels: Default::default(),
137                next_message_id: Default::default(),
138            }),
139        );
140        connection_id
141    }
142
143    pub async fn disconnect(&self, connection_id: ConnectionId) {
144        self.connections.write().await.remove(&connection_id);
145        self.connection_close_barriers
146            .write()
147            .await
148            .remove(&connection_id);
149    }
150
151    pub fn handle_messages(
152        self: &Arc<Self>,
153        connection_id: ConnectionId,
154    ) -> impl Future<Output = Result<()>> + 'static {
155        let (close_tx, mut close_rx) = barrier::channel();
156        let this = self.clone();
157        async move {
158            this.connection_close_barriers
159                .write()
160                .await
161                .insert(connection_id, close_tx);
162            let connection = this.connection(connection_id).await?;
163            let closed = close_rx.recv();
164            futures::pin_mut!(closed);
165
166            loop {
167                let mut reader = connection.reader.lock().await;
168                let read_message = reader.read_message();
169                futures::pin_mut!(read_message);
170
171                match futures::future::select(read_message, &mut closed).await {
172                    Either::Left((Ok(incoming), _)) => {
173                        if let Some(responding_to) = incoming.responding_to {
174                            let channel = connection
175                                .response_channels
176                                .lock()
177                                .await
178                                .remove(&responding_to);
179                            if let Some(mut tx) = channel {
180                                tx.send(incoming).await.ok();
181                            } else {
182                                log::warn!(
183                                    "received RPC response to unknown request {}",
184                                    responding_to
185                                );
186                            }
187                        } else {
188                            let mut envelope = Some(incoming);
189                            let mut handler_index = None;
190                            let mut handler_was_dropped = false;
191                            for (i, handler) in
192                                this.message_handlers.read().await.iter().enumerate()
193                            {
194                                if let Some(future) = handler(&mut envelope, connection_id) {
195                                    handler_was_dropped = future.await;
196                                    handler_index = Some(i);
197                                    break;
198                                }
199                            }
200
201                            if let Some(handler_index) = handler_index {
202                                if handler_was_dropped {
203                                    drop(this.message_handlers.write().await.remove(handler_index));
204                                }
205                            } else {
206                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
207                            }
208                        }
209                    }
210                    Either::Left((Err(error), _)) => {
211                        log::warn!("received invalid RPC message: {}", error);
212                        Err(error)?;
213                    }
214                    Either::Right(_) => return Ok(()),
215                }
216            }
217        }
218    }
219
220    pub async fn receive<M: EnvelopedMessage>(
221        self: &Arc<Self>,
222        connection_id: ConnectionId,
223    ) -> Result<TypedEnvelope<M>> {
224        let connection = self.connection(connection_id).await?;
225        let envelope = connection.reader.lock().await.read_message().await?;
226        let original_sender_id = envelope.original_sender_id;
227        let message_id = envelope.id;
228        let payload =
229            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
230        Ok(TypedEnvelope {
231            sender_id: connection_id,
232            original_sender_id: original_sender_id.map(PeerId),
233            message_id,
234            payload,
235        })
236    }
237
238    pub fn request<T: RequestMessage>(
239        self: &Arc<Self>,
240        receiver_id: ConnectionId,
241        request: T,
242    ) -> impl Future<Output = Result<T::Response>> {
243        self.request_internal(None, receiver_id, request)
244    }
245
246    pub fn forward_request<T: RequestMessage>(
247        self: &Arc<Self>,
248        sender_id: ConnectionId,
249        receiver_id: ConnectionId,
250        request: T,
251    ) -> impl Future<Output = Result<T::Response>> {
252        self.request_internal(Some(sender_id), receiver_id, request)
253    }
254
255    pub fn request_internal<T: RequestMessage>(
256        self: &Arc<Self>,
257        original_sender_id: Option<ConnectionId>,
258        receiver_id: ConnectionId,
259        request: T,
260    ) -> impl Future<Output = Result<T::Response>> {
261        let this = self.clone();
262        let (tx, mut rx) = oneshot::channel();
263        async move {
264            let connection = this.connection(receiver_id).await?;
265            let message_id = connection
266                .next_message_id
267                .fetch_add(1, atomic::Ordering::SeqCst);
268            connection
269                .response_channels
270                .lock()
271                .await
272                .insert(message_id, tx);
273            connection
274                .writer
275                .lock()
276                .await
277                .write_message(&request.into_envelope(
278                    message_id,
279                    None,
280                    original_sender_id.map(|id| id.0),
281                ))
282                .await?;
283            let response = rx
284                .recv()
285                .await
286                .expect("response channel was unexpectedly dropped");
287            T::Response::from_envelope(response)
288                .ok_or_else(|| anyhow!("received response of the wrong type"))
289        }
290    }
291
292    pub fn send<T: EnvelopedMessage>(
293        self: &Arc<Self>,
294        connection_id: ConnectionId,
295        message: T,
296    ) -> impl Future<Output = Result<()>> {
297        let this = self.clone();
298        async move {
299            let connection = this.connection(connection_id).await?;
300            let message_id = connection
301                .next_message_id
302                .fetch_add(1, atomic::Ordering::SeqCst);
303            connection
304                .writer
305                .lock()
306                .await
307                .write_message(&message.into_envelope(message_id, None, None))
308                .await?;
309            Ok(())
310        }
311    }
312
313    pub fn respond<T: RequestMessage>(
314        self: &Arc<Self>,
315        receipt: Receipt<T>,
316        response: T::Response,
317    ) -> impl Future<Output = Result<()>> {
318        let this = self.clone();
319        async move {
320            let connection = this.connection(receipt.sender_id).await?;
321            let message_id = connection
322                .next_message_id
323                .fetch_add(1, atomic::Ordering::SeqCst);
324            connection
325                .writer
326                .lock()
327                .await
328                .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
329                .await?;
330            Ok(())
331        }
332    }
333
334    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
335        Ok(self
336            .connections
337            .read()
338            .await
339            .get(&id)
340            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
341            .clone())
342    }
343}
344
345impl fmt::Display for ConnectionId {
346    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347        self.0.fmt(f)
348    }
349}
350
351impl fmt::Display for PeerId {
352    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353        self.0.fmt(f)
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients send requests to one server over Unix sockets; each request must get
    /// the response the server sends back for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: "path/two".to_string(),
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: "path/one".to_string(),
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The task owns its own copies; the originals are used by the clients below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            // Each client must disconnect its own connection id.
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting must end the `handle_messages` loop and close the socket.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on a closed socket must fail with the underlying I/O error.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}