peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{
  6    future::{BoxFuture, LocalBoxFuture},
  7    stream::{SplitSink, SplitStream},
  8    FutureExt, StreamExt,
  9};
 10use postage::{
 11    mpsc,
 12    prelude::{Sink, Stream},
 13};
 14use std::{
 15    any::TypeId,
 16    collections::{HashMap, HashSet},
 17    fmt,
 18    future::Future,
 19    marker::PhantomData,
 20    sync::{
 21        atomic::{self, AtomicU32},
 22        Arc,
 23    },
 24};
 25
 26#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 27pub struct ConnectionId(pub u32);
 28
 29#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 30pub struct PeerId(pub u32);
 31
 32type MessageHandler = Box<
 33    dyn Send
 34        + Sync
 35        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
 36>;
 37
 38type ForegroundMessageHandler =
 39    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 40
/// Proof that a request of type `T` was received. It captures what is needed
/// to address the reply: the connection the request arrived on and the
/// request's message id.
///
/// Obtained from [`TypedEnvelope::receipt`] and consumed by [`Peer::respond`].
/// `T` only ties the receipt to the request type (and therefore to
/// `T::Response`); no value of `T` is stored.
pub struct Receipt<T> {
    // Connection the request arrived on; the response is sent back over it.
    sender_id: ConnectionId,
    // Id of the request envelope; `Peer::respond` passes it to `into_envelope`
    // as the id being responded to.
    message_id: u32,
    payload_type: PhantomData<T>,
}
 46
/// A decoded incoming message of type `T`, plus the routing metadata taken
/// from its envelope.
pub struct TypedEnvelope<T> {
    /// Connection on which the message was received.
    pub sender_id: ConnectionId,
    // Peer on whose behalf the message was forwarded, if the envelope carried
    // one. Read through `original_sender_id()`, which errors when it is absent.
    original_sender_id: Option<PeerId>,
    /// The envelope's message id (used to build a [`Receipt`] for requests).
    pub message_id: u32,
    /// The decoded message.
    pub payload: T,
}
 53
 54impl<T> TypedEnvelope<T> {
 55    pub fn original_sender_id(&self) -> Result<PeerId> {
 56        self.original_sender_id
 57            .ok_or_else(|| anyhow!("missing original_sender_id"))
 58    }
 59}
 60
 61impl<T: RequestMessage> TypedEnvelope<T> {
 62    pub fn receipt(&self) -> Receipt<T> {
 63        Receipt {
 64            sender_id: self.sender_id,
 65            message_id: self.message_id,
 66            payload_type: PhantomData,
 67        }
 68    }
 69}
 70
/// Router whose handlers, and the futures they return, are `Send`.
pub type Router = RouterInternal<MessageHandler>;
/// Router whose handlers need not be `Send`/`Sync` and may return futures
/// that are not `Send`.
pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
/// Ordered list of message handlers, shared by [`Router`] and
/// [`ForegroundRouter`]; `H` is the boxed handler type.
pub struct RouterInternal<H> {
    // Tried in registration order by `handle`; the first one that claims an
    // envelope processes it.
    message_handlers: Vec<H>,
    // Message types that already have a handler; registering a second handler
    // for the same type panics.
    handler_types: HashSet<TypeId>,
}
 77
/// Multiplexes RPC traffic over any number of connections, each identified by
/// a [`ConnectionId`].
pub struct Peer {
    // Live connections. `disconnect`/`reset` remove entries; after that, new
    // sends and requests to those ids fail with "no such connection".
    connections: RwLock<HashMap<ConnectionId, Connection>>,
    // Source of fresh connection ids, incremented by `add_connection`.
    next_connection_id: AtomicU32,
}
 82
/// Per-connection state, shared by cloning between the peer and in-flight
/// operations. Every field is a channel handle or an `Arc`, so clones refer to
/// the same underlying connection.
#[derive(Clone)]
struct Connection {
    // Queue drained by this connection's `IOHandler::run`, which writes each
    // envelope to the socket.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    // Id assigned to the next envelope sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending requests, keyed by the request's message id. The message loop
    // removes an entry when a response carrying that id arrives, and clears
    // the map when the connection's incoming channel closes.
    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
}
 89
/// Owns the socket halves of one connection and moves envelopes between them
/// and the peer's channels.
///
/// Returned by [`Peer::add_connection`]. Drive it with [`IOHandler::run`], or
/// read a single message directly with `receive`.
pub struct IOHandler<W, R> {
    connection_id: ConnectionId,
    // Decoded envelopes read from `reader` are forwarded here.
    incoming_tx: mpsc::Sender<proto::Envelope>,
    // Envelopes queued by the peer, written to `writer`.
    outgoing_rx: mpsc::Receiver<proto::Envelope>,
    writer: MessageStream<W>,
    reader: MessageStream<R>,
}
 97
 98impl Peer {
 99    pub fn new() -> Arc<Self> {
100        Arc::new(Self {
101            connections: Default::default(),
102            next_connection_id: Default::default(),
103        })
104    }
105
106    pub async fn add_connection<Conn, H, Fut>(
107        self: &Arc<Self>,
108        conn: Conn,
109        router: Arc<RouterInternal<H>>,
110    ) -> (
111        ConnectionId,
112        IOHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
113        impl Future<Output = anyhow::Result<()>>,
114    )
115    where
116        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
117        Fut: Future<Output = ()>,
118        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
119            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
120            + Unpin,
121    {
122        let (tx, rx) = conn.split();
123        let connection_id = ConnectionId(
124            self.next_connection_id
125                .fetch_add(1, atomic::Ordering::SeqCst),
126        );
127        let (incoming_tx, mut incoming_rx) = mpsc::channel(64);
128        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
129        let connection = Connection {
130            outgoing_tx,
131            next_message_id: Default::default(),
132            response_channels: Default::default(),
133        };
134        let handle_io = IOHandler {
135            connection_id,
136            outgoing_rx,
137            incoming_tx,
138            writer: MessageStream::new(tx),
139            reader: MessageStream::new(rx),
140        };
141
142        let response_channels = connection.response_channels.clone();
143        let handle_messages = async move {
144            while let Some(message) = incoming_rx.recv().await {
145                if let Some(responding_to) = message.responding_to {
146                    let channel = response_channels.lock().await.remove(&responding_to);
147                    if let Some(mut tx) = channel {
148                        tx.send(message).await.ok();
149                    } else {
150                        log::warn!("received RPC response to unknown request {}", responding_to);
151                    }
152                } else {
153                    router.handle(connection_id, message).await;
154                }
155            }
156            response_channels.lock().await.clear();
157            Ok(())
158        };
159
160        self.connections
161            .write()
162            .await
163            .insert(connection_id, connection);
164
165        (connection_id, handle_io, handle_messages)
166    }
167
168    pub async fn disconnect(&self, connection_id: ConnectionId) {
169        self.connections.write().await.remove(&connection_id);
170    }
171
172    pub async fn reset(&self) {
173        self.connections.write().await.clear();
174    }
175
176    pub fn request<T: RequestMessage>(
177        self: &Arc<Self>,
178        receiver_id: ConnectionId,
179        request: T,
180    ) -> impl Future<Output = Result<T::Response>> {
181        self.request_internal(None, receiver_id, request)
182    }
183
184    pub fn forward_request<T: RequestMessage>(
185        self: &Arc<Self>,
186        sender_id: ConnectionId,
187        receiver_id: ConnectionId,
188        request: T,
189    ) -> impl Future<Output = Result<T::Response>> {
190        self.request_internal(Some(sender_id), receiver_id, request)
191    }
192
193    pub fn request_internal<T: RequestMessage>(
194        self: &Arc<Self>,
195        original_sender_id: Option<ConnectionId>,
196        receiver_id: ConnectionId,
197        request: T,
198    ) -> impl Future<Output = Result<T::Response>> {
199        let this = self.clone();
200        let (tx, mut rx) = mpsc::channel(1);
201        async move {
202            let mut connection = this.connection(receiver_id).await?;
203            let message_id = connection
204                .next_message_id
205                .fetch_add(1, atomic::Ordering::SeqCst);
206            connection
207                .response_channels
208                .lock()
209                .await
210                .insert(message_id, tx);
211            connection
212                .outgoing_tx
213                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
214                .await?;
215            let response = rx
216                .recv()
217                .await
218                .ok_or_else(|| anyhow!("connection was closed"))?;
219            T::Response::from_envelope(response)
220                .ok_or_else(|| anyhow!("received response of the wrong type"))
221        }
222    }
223
224    pub fn send<T: EnvelopedMessage>(
225        self: &Arc<Self>,
226        receiver_id: ConnectionId,
227        message: T,
228    ) -> impl Future<Output = Result<()>> {
229        let this = self.clone();
230        async move {
231            let mut connection = this.connection(receiver_id).await?;
232            let message_id = connection
233                .next_message_id
234                .fetch_add(1, atomic::Ordering::SeqCst);
235            connection
236                .outgoing_tx
237                .send(message.into_envelope(message_id, None, None))
238                .await?;
239            Ok(())
240        }
241    }
242
243    pub fn forward_send<T: EnvelopedMessage>(
244        self: &Arc<Self>,
245        sender_id: ConnectionId,
246        receiver_id: ConnectionId,
247        message: T,
248    ) -> impl Future<Output = Result<()>> {
249        let this = self.clone();
250        async move {
251            let mut connection = this.connection(receiver_id).await?;
252            let message_id = connection
253                .next_message_id
254                .fetch_add(1, atomic::Ordering::SeqCst);
255            connection
256                .outgoing_tx
257                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
258                .await?;
259            Ok(())
260        }
261    }
262
263    pub fn respond<T: RequestMessage>(
264        self: &Arc<Self>,
265        receipt: Receipt<T>,
266        response: T::Response,
267    ) -> impl Future<Output = Result<()>> {
268        let this = self.clone();
269        async move {
270            let mut connection = this.connection(receipt.sender_id).await?;
271            let message_id = connection
272                .next_message_id
273                .fetch_add(1, atomic::Ordering::SeqCst);
274            connection
275                .outgoing_tx
276                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
277                .await?;
278            Ok(())
279        }
280    }
281
282    fn connection(
283        self: &Arc<Self>,
284        connection_id: ConnectionId,
285    ) -> impl Future<Output = Result<Connection>> {
286        let this = self.clone();
287        async move {
288            let connections = this.connections.read().await;
289            let connection = connections
290                .get(&connection_id)
291                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
292            Ok(connection.clone())
293        }
294    }
295}
296
297impl<H, Fut> RouterInternal<H>
298where
299    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
300    Fut: Future<Output = ()>,
301{
302    pub fn new() -> Self {
303        Self {
304            message_handlers: Default::default(),
305            handler_types: Default::default(),
306        }
307    }
308
309    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
310        let mut envelope = Some(message);
311        for handler in self.message_handlers.iter() {
312            if let Some(future) = handler(&mut envelope, connection_id) {
313                future.await;
314                return;
315            }
316        }
317        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
318    }
319}
320
321impl Router {
322    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
323    where
324        T: EnvelopedMessage,
325        Fut: 'static + Send + Future<Output = Result<()>>,
326        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
327    {
328        if !self.handler_types.insert(TypeId::of::<T>()) {
329            panic!("duplicate handler type");
330        }
331
332        self.message_handlers
333            .push(Box::new(move |envelope, connection_id| {
334                if envelope.as_ref().map_or(false, T::matches_envelope) {
335                    let envelope = Option::take(envelope).unwrap();
336                    let message_id = envelope.id;
337                    let future = handler(TypedEnvelope {
338                        sender_id: connection_id,
339                        original_sender_id: envelope.original_sender_id.map(PeerId),
340                        message_id,
341                        payload: T::from_envelope(envelope).unwrap(),
342                    });
343                    Some(
344                        async move {
345                            if let Err(error) = future.await {
346                                log::error!(
347                                    "error handling message {} {}: {:?}",
348                                    T::NAME,
349                                    message_id,
350                                    error
351                                );
352                            }
353                        }
354                        .boxed(),
355                    )
356                } else {
357                    None
358                }
359            }));
360    }
361}
362
363impl ForegroundRouter {
364    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
365    where
366        T: EnvelopedMessage,
367        Fut: 'static + Future<Output = Result<()>>,
368        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
369    {
370        if !self.handler_types.insert(TypeId::of::<T>()) {
371            panic!("duplicate handler type");
372        }
373
374        self.message_handlers
375            .push(Box::new(move |envelope, connection_id| {
376                if envelope.as_ref().map_or(false, T::matches_envelope) {
377                    let envelope = Option::take(envelope).unwrap();
378                    let message_id = envelope.id;
379                    let future = handler(TypedEnvelope {
380                        sender_id: connection_id,
381                        original_sender_id: envelope.original_sender_id.map(PeerId),
382                        message_id,
383                        payload: T::from_envelope(envelope).unwrap(),
384                    });
385                    Some(
386                        async move {
387                            if let Err(error) = future.await {
388                                log::error!(
389                                    "error handling message {} {}: {:?}",
390                                    T::NAME,
391                                    message_id,
392                                    error
393                                );
394                            }
395                        }
396                        .boxed_local(),
397                    )
398                } else {
399                    None
400                }
401            }));
402    }
403}
404
405impl<W, R> IOHandler<W, R>
406where
407    W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
408    R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
409{
410    pub async fn run(mut self) -> Result<()> {
411        loop {
412            let read_message = self.reader.read_message().fuse();
413            futures::pin_mut!(read_message);
414            loop {
415                futures::select_biased! {
416                    incoming = read_message => match incoming {
417                        Ok(incoming) => {
418                            if self.incoming_tx.send(incoming).await.is_err() {
419                                return Ok(());
420                            }
421                            break;
422                        }
423                        Err(error) => {
424                            Err(error).context("received invalid RPC message")?;
425                        }
426                    },
427                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
428                        Some(outgoing) => {
429                            if let Err(result) = self.writer.write_message(&outgoing).await {
430                                Err(result).context("failed to write RPC message")?;
431                            }
432                        }
433                        None => return Ok(()),
434                    }
435                }
436            }
437        }
438    }
439
440    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
441        let envelope = self.reader.read_message().await?;
442        let original_sender_id = envelope.original_sender_id;
443        let message_id = envelope.id;
444        let payload =
445            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
446        Ok(TypedEnvelope {
447            sender_id: self.connection_id,
448            original_sender_id: original_sender_id.map(PeerId),
449            message_id,
450            payload,
451        })
452    }
453}
454
455impl<T> Clone for Receipt<T> {
456    fn clone(&self) -> Self {
457        Self {
458            sender_id: self.sender_id,
459            message_id: self.message_id,
460            payload_type: PhantomData,
461        }
462    }
463}
464
465impl<T> Copy for Receipt<T> {}
466
467impl fmt::Display for ConnectionId {
468    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469        self.0.fmt(f)
470    }
471}
472
473impl fmt::Display for PeerId {
474    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475        self.0.fmt(f)
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Two clients, each connected to one server over an in-memory channel,
    /// send requests to the server and receive the expected responses.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1.run()).detach();
            smol::spawn(io_task2.run()).detach();
            smol::spawn(io_task3.run()).detach();
            smol::spawn(io_task4.run()).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client tears down its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;
        });
    }

    /// Disconnecting a connection stops both its I/O pump and its message
    /// loop, and closes the socket for the remote end.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.run().await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await.ok();
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose remote end is already gone fails with
    /// "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler.run()).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}