// peer.rs

  use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{
    future::{self, BoxFuture, LocalBoxFuture},
    FutureExt, Stream, StreamExt,
};
use postage::{
    broadcast, mpsc,
    prelude::{Sink as _, Stream as _},
};
use std::{
    any::{Any, TypeId},
    collections::{HashMap, HashSet},
    fmt,
    future::Future,
    marker::PhantomData,
    sync::{
        atomic::{self, AtomicU32},
        Arc,
    },
};
 24
 /// Identifies one connection registered with a [`Peer`].
///
/// Ids are allocated sequentially by `Peer::add_connection` and are only unique
/// within the peer that allocated them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
 27
 /// Identifies the original sender of a forwarded message, as carried in an
/// envelope's `original_sender_id` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
 30
 31type MessageHandler = Box<
 32    dyn Send
 33        + Sync
 34        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
 35>;
 36
 37type ForegroundMessageHandler =
 38    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 39
 40pub struct Receipt<T> {
 41    sender_id: ConnectionId,
 42    message_id: u32,
 43    payload_type: PhantomData<T>,
 44}
 45
 46pub struct TypedEnvelope<T> {
 47    pub sender_id: ConnectionId,
 48    pub original_sender_id: Option<PeerId>,
 49    pub message_id: u32,
 50    pub payload: T,
 51}
 52
 53impl<T> TypedEnvelope<T> {
 54    pub fn original_sender_id(&self) -> Result<PeerId> {
 55        self.original_sender_id
 56            .ok_or_else(|| anyhow!("missing original_sender_id"))
 57    }
 58}
 59
 60impl<T: RequestMessage> TypedEnvelope<T> {
 61    pub fn receipt(&self) -> Receipt<T> {
 62        Receipt {
 63            sender_id: self.sender_id,
 64            message_id: self.message_id,
 65            payload_type: PhantomData,
 66        }
 67    }
 68}
 69
 70pub type Router = RouterInternal<MessageHandler>;
 71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
 72pub struct RouterInternal<H> {
 73    message_handlers: Vec<H>,
 74    handler_types: HashSet<TypeId>,
 75}
 76
 77pub struct Peer {
 78    connections: RwLock<HashMap<ConnectionId, Connection>>,
 79    next_connection_id: AtomicU32,
 80    incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
 81}
 82
 83#[derive(Clone)]
 84struct Connection {
 85    outgoing_tx: mpsc::Sender<proto::Envelope>,
 86    next_message_id: Arc<AtomicU32>,
 87    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 88}
 89
 90impl Peer {
 91    pub fn new() -> Arc<Self> {
 92        Arc::new(Self {
 93            connections: Default::default(),
 94            next_connection_id: Default::default(),
 95            incoming_messages: broadcast::channel(256).0,
 96        })
 97    }
 98
 99    pub async fn add_connection<Conn, H, Fut>(
100        self: &Arc<Self>,
101        conn: Conn,
102        router: Arc<RouterInternal<H>>,
103    ) -> (
104        ConnectionId,
105        impl Future<Output = anyhow::Result<()>> + Send,
106        impl Future<Output = ()>,
107    )
108    where
109        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
110        Fut: Future<Output = ()>,
111        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113            + Send
114            + Unpin,
115    {
116        let (tx, rx) = conn.split();
117        let connection_id = ConnectionId(
118            self.next_connection_id
119                .fetch_add(1, atomic::Ordering::SeqCst),
120        );
121        let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
122        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123        let connection = Connection {
124            outgoing_tx,
125            next_message_id: Default::default(),
126            response_channels: Default::default(),
127        };
128        let mut writer = MessageStream::new(tx);
129        let mut reader = MessageStream::new(rx);
130
131        let handle_io = async move {
132            loop {
133                let read_message = reader.read_message().fuse();
134                futures::pin_mut!(read_message);
135                loop {
136                    futures::select_biased! {
137                        incoming = read_message => match incoming {
138                            Ok(incoming) => {
139                                if incoming_tx.send(incoming).await.is_err() {
140                                    return Ok(());
141                                }
142                                break;
143                            }
144                            Err(error) => {
145                                Err(error).context("received invalid RPC message")?;
146                            }
147                        },
148                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
149                            Some(outgoing) => {
150                                if let Err(result) = writer.write_message(&outgoing).await {
151                                    Err(result).context("failed to write RPC message")?;
152                                }
153                            }
154                            None => return Ok(()),
155                        }
156                    }
157                }
158            }
159        };
160
161        let mut broadcast_incoming_messages = self.incoming_messages.clone();
162        let response_channels = connection.response_channels.clone();
163        let handle_messages = async move {
164            while let Some(envelope) = incoming_rx.recv().await {
165                if let Some(responding_to) = envelope.responding_to {
166                    let channel = response_channels.lock().await.remove(&responding_to);
167                    if let Some(mut tx) = channel {
168                        tx.send(envelope).await.ok();
169                    } else {
170                        log::warn!("received RPC response to unknown request {}", responding_to);
171                    }
172                } else {
173                    router.handle(connection_id, envelope.clone()).await;
174                    match proto::build_typed_envelope(connection_id, envelope) {
175                        Ok(envelope) => {
176                            broadcast_incoming_messages.send(envelope).await.ok();
177                        }
178                        Err(error) => log::error!("{}", error),
179                    }
180                }
181            }
182            response_channels.lock().await.clear();
183        };
184
185        self.connections
186            .write()
187            .await
188            .insert(connection_id, connection);
189
190        (connection_id, handle_io, handle_messages)
191    }
192
193    pub async fn disconnect(&self, connection_id: ConnectionId) {
194        self.connections.write().await.remove(&connection_id);
195    }
196
197    pub async fn reset(&self) {
198        self.connections.write().await.clear();
199    }
200
201    pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
202        self.incoming_messages
203            .subscribe()
204            .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
205    }
206
207    pub fn request<T: RequestMessage>(
208        self: &Arc<Self>,
209        receiver_id: ConnectionId,
210        request: T,
211    ) -> impl Future<Output = Result<T::Response>> {
212        self.request_internal(None, receiver_id, request)
213    }
214
215    pub fn forward_request<T: RequestMessage>(
216        self: &Arc<Self>,
217        sender_id: ConnectionId,
218        receiver_id: ConnectionId,
219        request: T,
220    ) -> impl Future<Output = Result<T::Response>> {
221        self.request_internal(Some(sender_id), receiver_id, request)
222    }
223
224    pub fn request_internal<T: RequestMessage>(
225        self: &Arc<Self>,
226        original_sender_id: Option<ConnectionId>,
227        receiver_id: ConnectionId,
228        request: T,
229    ) -> impl Future<Output = Result<T::Response>> {
230        let this = self.clone();
231        let (tx, mut rx) = mpsc::channel(1);
232        async move {
233            let mut connection = this.connection(receiver_id).await?;
234            let message_id = connection
235                .next_message_id
236                .fetch_add(1, atomic::Ordering::SeqCst);
237            connection
238                .response_channels
239                .lock()
240                .await
241                .insert(message_id, tx);
242            connection
243                .outgoing_tx
244                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
245                .await
246                .map_err(|_| anyhow!("connection was closed"))?;
247            let response = rx
248                .recv()
249                .await
250                .ok_or_else(|| anyhow!("connection was closed"))?;
251            T::Response::from_envelope(response)
252                .ok_or_else(|| anyhow!("received response of the wrong type"))
253        }
254    }
255
256    pub fn send<T: EnvelopedMessage>(
257        self: &Arc<Self>,
258        receiver_id: ConnectionId,
259        message: T,
260    ) -> impl Future<Output = Result<()>> {
261        let this = self.clone();
262        async move {
263            let mut connection = this.connection(receiver_id).await?;
264            let message_id = connection
265                .next_message_id
266                .fetch_add(1, atomic::Ordering::SeqCst);
267            connection
268                .outgoing_tx
269                .send(message.into_envelope(message_id, None, None))
270                .await?;
271            Ok(())
272        }
273    }
274
275    pub fn forward_send<T: EnvelopedMessage>(
276        self: &Arc<Self>,
277        sender_id: ConnectionId,
278        receiver_id: ConnectionId,
279        message: T,
280    ) -> impl Future<Output = Result<()>> {
281        let this = self.clone();
282        async move {
283            let mut connection = this.connection(receiver_id).await?;
284            let message_id = connection
285                .next_message_id
286                .fetch_add(1, atomic::Ordering::SeqCst);
287            connection
288                .outgoing_tx
289                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
290                .await?;
291            Ok(())
292        }
293    }
294
295    pub fn respond<T: RequestMessage>(
296        self: &Arc<Self>,
297        receipt: Receipt<T>,
298        response: T::Response,
299    ) -> impl Future<Output = Result<()>> {
300        let this = self.clone();
301        async move {
302            let mut connection = this.connection(receipt.sender_id).await?;
303            let message_id = connection
304                .next_message_id
305                .fetch_add(1, atomic::Ordering::SeqCst);
306            connection
307                .outgoing_tx
308                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
309                .await?;
310            Ok(())
311        }
312    }
313
314    fn connection(
315        self: &Arc<Self>,
316        connection_id: ConnectionId,
317    ) -> impl Future<Output = Result<Connection>> {
318        let this = self.clone();
319        async move {
320            let connections = this.connections.read().await;
321            let connection = connections
322                .get(&connection_id)
323                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
324            Ok(connection.clone())
325        }
326    }
327}
328
329impl<H, Fut> RouterInternal<H>
330where
331    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
332    Fut: Future<Output = ()>,
333{
334    pub fn new() -> Self {
335        Self {
336            message_handlers: Default::default(),
337            handler_types: Default::default(),
338        }
339    }
340
341    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
342        let mut envelope = Some(message);
343        for handler in self.message_handlers.iter() {
344            if let Some(future) = handler(&mut envelope, connection_id) {
345                future.await;
346                return;
347            }
348        }
349        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
350    }
351}
352
353impl Router {
354    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
355    where
356        T: EnvelopedMessage,
357        Fut: 'static + Send + Future<Output = Result<()>>,
358        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
359    {
360        if !self.handler_types.insert(TypeId::of::<T>()) {
361            panic!("duplicate handler type");
362        }
363
364        self.message_handlers
365            .push(Box::new(move |envelope, connection_id| {
366                if envelope.as_ref().map_or(false, T::matches_envelope) {
367                    let envelope = Option::take(envelope).unwrap();
368                    let message_id = envelope.id;
369                    let future = handler(TypedEnvelope {
370                        sender_id: connection_id,
371                        original_sender_id: envelope.original_sender_id.map(PeerId),
372                        message_id,
373                        payload: T::from_envelope(envelope).unwrap(),
374                    });
375                    Some(
376                        async move {
377                            if let Err(error) = future.await {
378                                log::error!(
379                                    "error handling message {} {}: {:?}",
380                                    T::NAME,
381                                    message_id,
382                                    error
383                                );
384                            }
385                        }
386                        .boxed(),
387                    )
388                } else {
389                    None
390                }
391            }));
392    }
393}
394
395impl ForegroundRouter {
396    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
397    where
398        T: EnvelopedMessage,
399        Fut: 'static + Future<Output = Result<()>>,
400        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
401    {
402        if !self.handler_types.insert(TypeId::of::<T>()) {
403            panic!("duplicate handler type");
404        }
405
406        self.message_handlers
407            .push(Box::new(move |envelope, connection_id| {
408                if envelope.as_ref().map_or(false, T::matches_envelope) {
409                    let envelope = Option::take(envelope).unwrap();
410                    let message_id = envelope.id;
411                    let future = handler(TypedEnvelope {
412                        sender_id: connection_id,
413                        original_sender_id: envelope.original_sender_id.map(PeerId),
414                        message_id,
415                        payload: T::from_envelope(envelope).unwrap(),
416                    });
417                    Some(
418                        async move {
419                            if let Err(error) = future.await {
420                                log::error!(
421                                    "error handling message {} {}: {:?}",
422                                    T::NAME,
423                                    message_id,
424                                    error
425                                );
426                            }
427                        }
428                        .boxed_local(),
429                    )
430                } else {
431                    None
432                }
433            }));
434    }
435}
436
437impl<T> Clone for Receipt<T> {
438    fn clone(&self) -> Self {
439        Self {
440            sender_id: self.sender_id,
441            message_id: self.message_id,
442            payload_type: PhantomData,
443        }
444    }
445}
446
447impl<T> Copy for Receipt<T> {}
448
449impl fmt::Display for ConnectionId {
450    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451        self.0.fmt(f)
452    }
453}
454
455impl fmt::Display for PeerId {
456    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457        self.0.fmt(f)
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each peer must be disconnected using the id of its *own* connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}