peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{
  6    future::{self, BoxFuture, LocalBoxFuture},
  7    FutureExt, Stream, StreamExt,
  8};
  9use postage::{
 10    broadcast, mpsc,
 11    prelude::{Sink as _, Stream as _},
 12};
 13use std::{
 14    any::{Any, TypeId},
 15    collections::{HashMap, HashSet},
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct PeerId(pub u32);
 30
 31type MessageHandler = Box<
 32    dyn Send
 33        + Sync
 34        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
 35>;
 36
 37type ForegroundMessageHandler =
 38    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 39
 40pub struct Receipt<T> {
 41    pub sender_id: ConnectionId,
 42    pub message_id: u32,
 43    payload_type: PhantomData<T>,
 44}
 45
 46pub struct TypedEnvelope<T> {
 47    pub sender_id: ConnectionId,
 48    pub original_sender_id: Option<PeerId>,
 49    pub message_id: u32,
 50    pub payload: T,
 51}
 52
 53impl<T> TypedEnvelope<T> {
 54    pub fn original_sender_id(&self) -> Result<PeerId> {
 55        self.original_sender_id
 56            .ok_or_else(|| anyhow!("missing original_sender_id"))
 57    }
 58}
 59
 60impl<T: RequestMessage> TypedEnvelope<T> {
 61    pub fn receipt(&self) -> Receipt<T> {
 62        Receipt {
 63            sender_id: self.sender_id,
 64            message_id: self.message_id,
 65            payload_type: PhantomData,
 66        }
 67    }
 68}
 69
 70pub type Router = RouterInternal<MessageHandler>;
 71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
 72pub struct RouterInternal<H> {
 73    message_handlers: Vec<H>,
 74    handler_types: HashSet<TypeId>,
 75}
 76
 77pub struct Peer {
 78    connections: RwLock<HashMap<ConnectionId, Connection>>,
 79    next_connection_id: AtomicU32,
 80    incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
 81}
 82
 83#[derive(Clone)]
 84struct Connection {
 85    outgoing_tx: mpsc::Sender<proto::Envelope>,
 86    next_message_id: Arc<AtomicU32>,
 87    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 88}
 89
 90impl Peer {
 91    pub fn new() -> Arc<Self> {
 92        Arc::new(Self {
 93            connections: Default::default(),
 94            next_connection_id: Default::default(),
 95            incoming_messages: broadcast::channel(256).0,
 96        })
 97    }
 98
 99    pub async fn add_connection<Conn, H, Fut>(
100        self: &Arc<Self>,
101        conn: Conn,
102        router: Arc<RouterInternal<H>>,
103    ) -> (
104        ConnectionId,
105        impl Future<Output = anyhow::Result<()>> + Send,
106        impl Future<Output = ()>,
107    )
108    where
109        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
110        Fut: Future<Output = ()>,
111        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113            + Send
114            + Unpin,
115    {
116        let (tx, rx) = conn.split();
117        let connection_id = ConnectionId(
118            self.next_connection_id
119                .fetch_add(1, atomic::Ordering::SeqCst),
120        );
121        let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
122        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123        let connection = Connection {
124            outgoing_tx,
125            next_message_id: Default::default(),
126            response_channels: Default::default(),
127        };
128        let mut writer = MessageStream::new(tx);
129        let mut reader = MessageStream::new(rx);
130
131        let handle_io = async move {
132            loop {
133                let read_message = reader.read_message().fuse();
134                futures::pin_mut!(read_message);
135                loop {
136                    futures::select_biased! {
137                        incoming = read_message => match incoming {
138                            Ok(incoming) => {
139                                if incoming_tx.send(incoming).await.is_err() {
140                                    return Ok(());
141                                }
142                                break;
143                            }
144                            Err(error) => {
145                                Err(error).context("received invalid RPC message")?;
146                            }
147                        },
148                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
149                            Some(outgoing) => {
150                                if let Err(result) = writer.write_message(&outgoing).await {
151                                    Err(result).context("failed to write RPC message")?;
152                                }
153                            }
154                            None => return Ok(()),
155                        }
156                    }
157                }
158            }
159        };
160
161        let mut broadcast_incoming_messages = self.incoming_messages.clone();
162        let response_channels = connection.response_channels.clone();
163        let handle_messages = async move {
164            while let Some(envelope) = incoming_rx.recv().await {
165                if let Some(responding_to) = envelope.responding_to {
166                    let channel = response_channels.lock().await.remove(&responding_to);
167                    if let Some(mut tx) = channel {
168                        tx.send(envelope).await.ok();
169                    } else {
170                        log::warn!("received RPC response to unknown request {}", responding_to);
171                    }
172                } else {
173                    router.handle(connection_id, envelope.clone()).await;
174                    if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
175                        broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
176                    } else {
177                        log::error!("unable to construct a typed envelope");
178                    }
179                }
180            }
181            response_channels.lock().await.clear();
182        };
183
184        self.connections
185            .write()
186            .await
187            .insert(connection_id, connection);
188
189        (connection_id, handle_io, handle_messages)
190    }
191
192    pub async fn disconnect(&self, connection_id: ConnectionId) {
193        self.connections.write().await.remove(&connection_id);
194    }
195
196    pub async fn reset(&self) {
197        self.connections.write().await.clear();
198    }
199
200    pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
201        self.incoming_messages
202            .subscribe()
203            .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
204    }
205
206    pub fn request<T: RequestMessage>(
207        self: &Arc<Self>,
208        receiver_id: ConnectionId,
209        request: T,
210    ) -> impl Future<Output = Result<T::Response>> {
211        self.request_internal(None, receiver_id, request)
212    }
213
214    pub fn forward_request<T: RequestMessage>(
215        self: &Arc<Self>,
216        sender_id: ConnectionId,
217        receiver_id: ConnectionId,
218        request: T,
219    ) -> impl Future<Output = Result<T::Response>> {
220        self.request_internal(Some(sender_id), receiver_id, request)
221    }
222
223    pub fn request_internal<T: RequestMessage>(
224        self: &Arc<Self>,
225        original_sender_id: Option<ConnectionId>,
226        receiver_id: ConnectionId,
227        request: T,
228    ) -> impl Future<Output = Result<T::Response>> {
229        let this = self.clone();
230        let (tx, mut rx) = mpsc::channel(1);
231        async move {
232            let mut connection = this.connection(receiver_id).await?;
233            let message_id = connection
234                .next_message_id
235                .fetch_add(1, atomic::Ordering::SeqCst);
236            connection
237                .response_channels
238                .lock()
239                .await
240                .insert(message_id, tx);
241            connection
242                .outgoing_tx
243                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
244                .await
245                .map_err(|_| anyhow!("connection was closed"))?;
246            let response = rx
247                .recv()
248                .await
249                .ok_or_else(|| anyhow!("connection was closed"))?;
250            T::Response::from_envelope(response)
251                .ok_or_else(|| anyhow!("received response of the wrong type"))
252        }
253    }
254
255    pub fn send<T: EnvelopedMessage>(
256        self: &Arc<Self>,
257        receiver_id: ConnectionId,
258        message: T,
259    ) -> impl Future<Output = Result<()>> {
260        let this = self.clone();
261        async move {
262            let mut connection = this.connection(receiver_id).await?;
263            let message_id = connection
264                .next_message_id
265                .fetch_add(1, atomic::Ordering::SeqCst);
266            connection
267                .outgoing_tx
268                .send(message.into_envelope(message_id, None, None))
269                .await?;
270            Ok(())
271        }
272    }
273
274    pub fn forward_send<T: EnvelopedMessage>(
275        self: &Arc<Self>,
276        sender_id: ConnectionId,
277        receiver_id: ConnectionId,
278        message: T,
279    ) -> impl Future<Output = Result<()>> {
280        let this = self.clone();
281        async move {
282            let mut connection = this.connection(receiver_id).await?;
283            let message_id = connection
284                .next_message_id
285                .fetch_add(1, atomic::Ordering::SeqCst);
286            connection
287                .outgoing_tx
288                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
289                .await?;
290            Ok(())
291        }
292    }
293
294    pub fn respond<T: RequestMessage>(
295        self: &Arc<Self>,
296        receipt: Receipt<T>,
297        response: T::Response,
298    ) -> impl Future<Output = Result<()>> {
299        let this = self.clone();
300        async move {
301            let mut connection = this.connection(receipt.sender_id).await?;
302            let message_id = connection
303                .next_message_id
304                .fetch_add(1, atomic::Ordering::SeqCst);
305            connection
306                .outgoing_tx
307                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
308                .await?;
309            Ok(())
310        }
311    }
312
313    fn connection(
314        self: &Arc<Self>,
315        connection_id: ConnectionId,
316    ) -> impl Future<Output = Result<Connection>> {
317        let this = self.clone();
318        async move {
319            let connections = this.connections.read().await;
320            let connection = connections
321                .get(&connection_id)
322                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
323            Ok(connection.clone())
324        }
325    }
326}
327
328impl<H, Fut> RouterInternal<H>
329where
330    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
331    Fut: Future<Output = ()>,
332{
333    pub fn new() -> Self {
334        Self {
335            message_handlers: Default::default(),
336            handler_types: Default::default(),
337        }
338    }
339
340    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
341        let mut envelope = Some(message);
342        for handler in self.message_handlers.iter() {
343            if let Some(future) = handler(&mut envelope, connection_id) {
344                future.await;
345                return;
346            }
347        }
348        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
349    }
350}
351
352impl Router {
353    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
354    where
355        T: EnvelopedMessage,
356        Fut: 'static + Send + Future<Output = Result<()>>,
357        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
358    {
359        if !self.handler_types.insert(TypeId::of::<T>()) {
360            panic!("duplicate handler type");
361        }
362
363        self.message_handlers
364            .push(Box::new(move |envelope, connection_id| {
365                if envelope.as_ref().map_or(false, T::matches_envelope) {
366                    let envelope = Option::take(envelope).unwrap();
367                    let message_id = envelope.id;
368                    let future = handler(TypedEnvelope {
369                        sender_id: connection_id,
370                        original_sender_id: envelope.original_sender_id.map(PeerId),
371                        message_id,
372                        payload: T::from_envelope(envelope).unwrap(),
373                    });
374                    Some(
375                        async move {
376                            if let Err(error) = future.await {
377                                log::error!(
378                                    "error handling message {} {}: {:?}",
379                                    T::NAME,
380                                    message_id,
381                                    error
382                                );
383                            }
384                        }
385                        .boxed(),
386                    )
387                } else {
388                    None
389                }
390            }));
391    }
392}
393
394impl ForegroundRouter {
395    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
396    where
397        T: EnvelopedMessage,
398        Fut: 'static + Future<Output = Result<()>>,
399        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
400    {
401        if !self.handler_types.insert(TypeId::of::<T>()) {
402            panic!("duplicate handler type");
403        }
404
405        self.message_handlers
406            .push(Box::new(move |envelope, connection_id| {
407                if envelope.as_ref().map_or(false, T::matches_envelope) {
408                    let envelope = Option::take(envelope).unwrap();
409                    let message_id = envelope.id;
410                    let future = handler(TypedEnvelope {
411                        sender_id: connection_id,
412                        original_sender_id: envelope.original_sender_id.map(PeerId),
413                        message_id,
414                        payload: T::from_envelope(envelope).unwrap(),
415                    });
416                    Some(
417                        async move {
418                            if let Err(error) = future.await {
419                                log::error!(
420                                    "error handling message {} {}: {:?}",
421                                    T::NAME,
422                                    message_id,
423                                    error
424                                );
425                            }
426                        }
427                        .boxed_local(),
428                    )
429                } else {
430                    None
431                }
432            }));
433    }
434}
435
436impl<T> Clone for Receipt<T> {
437    fn clone(&self) -> Self {
438        Self {
439            sender_id: self.sender_id,
440            message_id: self.message_id,
441            payload_type: PhantomData,
442        }
443    }
444}
445
446impl<T> Copy for Receipt<T> {}
447
448impl fmt::Display for ConnectionId {
449    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450        self.0.fmt(f)
451    }
452}
453
454impl fmt::Display for PeerId {
455    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456        self.0.fmt(f)
457    }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Two clients each talk to one server over their own connection, and every request gets
    /// the response the server's handlers compute for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection. Connection ids are allocated per
            // peer, so `client1_conn_id` only coincidentally equals `client2_conn_id`.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;
        });
    }

    /// Disconnecting a connection ends both of its futures and closes the underlying socket.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request over a connection whose remote end is already gone fails with
    /// "connection was closed" instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}