peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{
  6    future::{BoxFuture, LocalBoxFuture},
  7    FutureExt, StreamExt,
  8};
  9use postage::{
 10    mpsc,
 11    prelude::{Sink, Stream},
 12};
 13use std::{
 14    any::TypeId,
 15    collections::{HashMap, HashSet},
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct PeerId(pub u32);
 30
/// Boxed handler stored by [`Router`]; it is `Send + Sync` and yields a `Send` future.
///
/// A handler is called with the pending envelope and the id of the connection
/// it arrived on. It returns `Some(future)` when it accepts the message and
/// `None` otherwise. The handlers registered in this file take the envelope
/// out of the `Option` when they accept it.
type MessageHandler = Box<
    dyn Send
        + Sync
        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
>;
 36
/// Boxed handler stored by [`ForegroundRouter`].
///
/// Same calling convention as `MessageHandler`, but without `Send`/`Sync`
/// bounds and returning a `LocalBoxFuture`.
type ForegroundMessageHandler =
    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 39
/// Proof that a request of type `T` was received, used to address its response.
///
/// `Peer::respond` sends the response over `sender_id` and marks it as
/// responding to `message_id`.
pub struct Receipt<T> {
    // Connection the request arrived on.
    sender_id: ConnectionId,
    // Id of the request message being answered.
    message_id: u32,
    // Ties the receipt to the request type without storing a `T`.
    payload_type: PhantomData<T>,
}
 45
 46pub struct TypedEnvelope<T> {
 47    pub sender_id: ConnectionId,
 48    original_sender_id: Option<PeerId>,
 49    pub message_id: u32,
 50    pub payload: T,
 51}
 52
 53impl<T> TypedEnvelope<T> {
 54    pub fn original_sender_id(&self) -> Result<PeerId> {
 55        self.original_sender_id
 56            .ok_or_else(|| anyhow!("missing original_sender_id"))
 57    }
 58}
 59
 60impl<T: RequestMessage> TypedEnvelope<T> {
 61    pub fn receipt(&self) -> Receipt<T> {
 62        Receipt {
 63            sender_id: self.sender_id,
 64            message_id: self.message_id,
 65            payload_type: PhantomData,
 66        }
 67    }
 68}
 69
/// Router whose handlers are `Send + Sync` and produce boxed `Send` futures.
pub type Router = RouterInternal<MessageHandler>;
/// Router whose handlers and futures carry no `Send` bound (`LocalBoxFuture`).
pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
/// Dispatches incoming envelopes to the first registered handler that accepts them.
pub struct RouterInternal<H> {
    // Tried in registration order by `handle`.
    message_handlers: Vec<H>,
    // Message types that already have a handler; used to reject duplicates.
    handler_types: HashSet<TypeId>,
}
 76
/// Owns the set of open connections and allocates their ids.
pub struct Peer {
    // Registered connections keyed by id; entries are removed by `disconnect`/`reset`.
    connections: RwLock<HashMap<ConnectionId, Connection>>,
    // Value of the next `ConnectionId`; incremented by every `add_connection`.
    next_connection_id: AtomicU32,
}
 81
/// State kept for one connection.
///
/// Cloning shares the message-id counter and the response-channel map (both
/// are behind `Arc`) and clones the outgoing sender.
#[derive(Clone)]
struct Connection {
    // Queue of envelopes consumed by the connection's I/O future.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    // Counter providing the id of the next message sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Senders for in-flight requests, keyed by the id of the request message.
    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
}
 88
 89impl Peer {
 90    pub fn new() -> Arc<Self> {
 91        Arc::new(Self {
 92            connections: Default::default(),
 93            next_connection_id: Default::default(),
 94        })
 95    }
 96
    /// Registers `conn` with this peer and returns what is needed to drive it.
    ///
    /// `conn` is split into a write half and a read half, each wrapped in a
    /// `MessageStream`. The connection is stored under a newly allocated
    /// [`ConnectionId`] before this function returns.
    ///
    /// Returns a tuple of:
    /// - the id assigned to the connection;
    /// - a `Send` future performing the I/O: it forwards messages read from
    ///   `conn` to the message-handling future and writes queued outgoing
    ///   messages to `conn`. It resolves to `Ok(())` once either internal
    ///   channel is closed, and to an error if reading or writing fails;
    /// - a future that dispatches every incoming message, either to the
    ///   pending request it responds to or to `router`.
    ///
    /// Neither future does anything until the caller polls it.
    pub async fn add_connection<Conn, H, Fut>(
        self: &Arc<Self>,
        conn: Conn,
        router: Arc<RouterInternal<H>>,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        impl Future<Output = ()>,
    )
    where
        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
        Fut: Future<Output = ()>,
        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
            + Send
            + Unpin,
    {
        let (tx, rx) = conn.split();
        let connection_id = ConnectionId(
            self.next_connection_id
                .fetch_add(1, atomic::Ordering::SeqCst),
        );
        // `incoming_*` carries messages from the I/O future to the handling
        // future; `outgoing_*` carries messages queued by send/request/respond
        // to the I/O future.
        let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
        let connection = Connection {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Default::default(),
        };
        let mut writer = MessageStream::new(tx);
        let mut reader = MessageStream::new(rx);

        let handle_io = async move {
            loop {
                // One read future per incoming message. The inner loop keeps
                // polling that same future while outgoing messages are written
                // in between, and `break`s out once the read has completed.
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    // `select_biased!` polls its arms in order, so a completed
                    // read is handled before a queued outgoing message.
                    futures::select_biased! {
                        incoming = read_message => match incoming {
                            Ok(incoming) => {
                                // A failed send means the receiving side is gone;
                                // end the I/O future without an error.
                                if incoming_tx.send(incoming).await.is_err() {
                                    return Ok(());
                                }
                                break;
                            }
                            Err(error) => {
                                Err(error).context("received invalid RPC message")?;
                            }
                        },
                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
                            Some(outgoing) => {
                                if let Err(result) = writer.write_message(&outgoing).await {
                                    Err(result).context("failed to write RPC message")?;
                                }
                            }
                            // The outgoing channel is closed: nothing more to write.
                            None => return Ok(()),
                        }
                    }
                }
            }
        };

        let response_channels = connection.response_channels.clone();
        let handle_messages = async move {
            while let Some(message) = incoming_rx.recv().await {
                if let Some(responding_to) = message.responding_to {
                    // A response: hand it to the request that registered for this id.
                    let channel = response_channels.lock().await.remove(&responding_to);
                    if let Some(mut tx) = channel {
                        tx.send(message).await.ok();
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }
                } else {
                    // Anything else goes through the router's handlers.
                    router.handle(connection_id, message).await;
                }
            }
            // No more incoming messages: drop the senders of all pending requests.
            response_channels.lock().await.clear();
        };

        self.connections
            .write()
            .await
            .insert(connection_id, connection);

        (connection_id, handle_io, handle_messages)
    }
183
184    pub async fn disconnect(&self, connection_id: ConnectionId) {
185        self.connections.write().await.remove(&connection_id);
186    }
187
188    pub async fn reset(&self) {
189        self.connections.write().await.clear();
190    }
191
    /// Sends `request` over the connection `receiver_id` and resolves to its typed response.
    ///
    /// Delegates to `request_internal` with no original sender.
    pub fn request<T: RequestMessage>(
        self: &Arc<Self>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
199
    /// Like `request`, but records `sender_id` as the original sender of the request.
    ///
    /// Delegates to `request_internal` with `Some(sender_id)`.
    pub fn forward_request<T: RequestMessage>(
        self: &Arc<Self>,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
208
209    pub fn request_internal<T: RequestMessage>(
210        self: &Arc<Self>,
211        original_sender_id: Option<ConnectionId>,
212        receiver_id: ConnectionId,
213        request: T,
214    ) -> impl Future<Output = Result<T::Response>> {
215        let this = self.clone();
216        let (tx, mut rx) = mpsc::channel(1);
217        async move {
218            let mut connection = this.connection(receiver_id).await?;
219            let message_id = connection
220                .next_message_id
221                .fetch_add(1, atomic::Ordering::SeqCst);
222            connection
223                .response_channels
224                .lock()
225                .await
226                .insert(message_id, tx);
227            connection
228                .outgoing_tx
229                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
230                .await?;
231            let response = rx
232                .recv()
233                .await
234                .ok_or_else(|| anyhow!("connection was closed"))?;
235            T::Response::from_envelope(response)
236                .ok_or_else(|| anyhow!("received response of the wrong type"))
237        }
238    }
239
240    pub fn send<T: EnvelopedMessage>(
241        self: &Arc<Self>,
242        receiver_id: ConnectionId,
243        message: T,
244    ) -> impl Future<Output = Result<()>> {
245        let this = self.clone();
246        async move {
247            let mut connection = this.connection(receiver_id).await?;
248            let message_id = connection
249                .next_message_id
250                .fetch_add(1, atomic::Ordering::SeqCst);
251            connection
252                .outgoing_tx
253                .send(message.into_envelope(message_id, None, None))
254                .await?;
255            Ok(())
256        }
257    }
258
259    pub fn forward_send<T: EnvelopedMessage>(
260        self: &Arc<Self>,
261        sender_id: ConnectionId,
262        receiver_id: ConnectionId,
263        message: T,
264    ) -> impl Future<Output = Result<()>> {
265        let this = self.clone();
266        async move {
267            let mut connection = this.connection(receiver_id).await?;
268            let message_id = connection
269                .next_message_id
270                .fetch_add(1, atomic::Ordering::SeqCst);
271            connection
272                .outgoing_tx
273                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
274                .await?;
275            Ok(())
276        }
277    }
278
279    pub fn respond<T: RequestMessage>(
280        self: &Arc<Self>,
281        receipt: Receipt<T>,
282        response: T::Response,
283    ) -> impl Future<Output = Result<()>> {
284        let this = self.clone();
285        async move {
286            let mut connection = this.connection(receipt.sender_id).await?;
287            let message_id = connection
288                .next_message_id
289                .fetch_add(1, atomic::Ordering::SeqCst);
290            connection
291                .outgoing_tx
292                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
293                .await?;
294            Ok(())
295        }
296    }
297
298    fn connection(
299        self: &Arc<Self>,
300        connection_id: ConnectionId,
301    ) -> impl Future<Output = Result<Connection>> {
302        let this = self.clone();
303        async move {
304            let connections = this.connections.read().await;
305            let connection = connections
306                .get(&connection_id)
307                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
308            Ok(connection.clone())
309        }
310    }
311}
312
313impl<H, Fut> RouterInternal<H>
314where
315    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
316    Fut: Future<Output = ()>,
317{
318    pub fn new() -> Self {
319        Self {
320            message_handlers: Default::default(),
321            handler_types: Default::default(),
322        }
323    }
324
325    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
326        let mut envelope = Some(message);
327        for handler in self.message_handlers.iter() {
328            if let Some(future) = handler(&mut envelope, connection_id) {
329                future.await;
330                return;
331            }
332        }
333        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
334    }
335}
336
337impl Router {
338    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
339    where
340        T: EnvelopedMessage,
341        Fut: 'static + Send + Future<Output = Result<()>>,
342        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
343    {
344        if !self.handler_types.insert(TypeId::of::<T>()) {
345            panic!("duplicate handler type");
346        }
347
348        self.message_handlers
349            .push(Box::new(move |envelope, connection_id| {
350                if envelope.as_ref().map_or(false, T::matches_envelope) {
351                    let envelope = Option::take(envelope).unwrap();
352                    let message_id = envelope.id;
353                    let future = handler(TypedEnvelope {
354                        sender_id: connection_id,
355                        original_sender_id: envelope.original_sender_id.map(PeerId),
356                        message_id,
357                        payload: T::from_envelope(envelope).unwrap(),
358                    });
359                    Some(
360                        async move {
361                            if let Err(error) = future.await {
362                                log::error!(
363                                    "error handling message {} {}: {:?}",
364                                    T::NAME,
365                                    message_id,
366                                    error
367                                );
368                            }
369                        }
370                        .boxed(),
371                    )
372                } else {
373                    None
374                }
375            }));
376    }
377}
378
379impl ForegroundRouter {
380    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
381    where
382        T: EnvelopedMessage,
383        Fut: 'static + Future<Output = Result<()>>,
384        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
385    {
386        if !self.handler_types.insert(TypeId::of::<T>()) {
387            panic!("duplicate handler type");
388        }
389
390        self.message_handlers
391            .push(Box::new(move |envelope, connection_id| {
392                if envelope.as_ref().map_or(false, T::matches_envelope) {
393                    let envelope = Option::take(envelope).unwrap();
394                    let message_id = envelope.id;
395                    let future = handler(TypedEnvelope {
396                        sender_id: connection_id,
397                        original_sender_id: envelope.original_sender_id.map(PeerId),
398                        message_id,
399                        payload: T::from_envelope(envelope).unwrap(),
400                    });
401                    Some(
402                        async move {
403                            if let Err(error) = future.await {
404                                log::error!(
405                                    "error handling message {} {}: {:?}",
406                                    T::NAME,
407                                    message_id,
408                                    error
409                                );
410                            }
411                        }
412                        .boxed_local(),
413                    )
414                } else {
415                    None
416                }
417            }));
418    }
419}
420
421impl<T> Clone for Receipt<T> {
422    fn clone(&self) -> Self {
423        Self {
424            sender_id: self.sender_id,
425            message_id: self.message_id,
426            payload_type: PhantomData,
427        }
428    }
429}
430
431impl<T> Copy for Receipt<T> {}
432
433impl fmt::Display for ConnectionId {
434    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435        self.0.fmt(f)
436    }
437}
438
439impl fmt::Display for PeerId {
440    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441        self.0.fmt(f)
442    }
443}
444
445#[cfg(test)]
446mod tests {
447    use super::*;
448    use crate::test;
449
450    #[test]
451    fn test_request_response() {
452        smol::block_on(async move {
453            // create 2 clients connected to 1 server
454            let server = Peer::new();
455            let client1 = Peer::new();
456            let client2 = Peer::new();
457
458            let mut router = Router::new();
459            router.add_message_handler({
460                let server = server.clone();
461                move |envelope: TypedEnvelope<proto::Auth>| {
462                    let server = server.clone();
463                    async move {
464                        let receipt = envelope.receipt();
465                        let message = envelope.payload;
466                        server
467                            .respond(
468                                receipt,
469                                match message.user_id {
470                                    1 => {
471                                        assert_eq!(message.access_token, "access-token-1");
472                                        proto::AuthResponse {
473                                            credentials_valid: true,
474                                        }
475                                    }
476                                    2 => {
477                                        assert_eq!(message.access_token, "access-token-2");
478                                        proto::AuthResponse {
479                                            credentials_valid: false,
480                                        }
481                                    }
482                                    _ => {
483                                        panic!("unexpected user id {}", message.user_id);
484                                    }
485                                },
486                            )
487                            .await
488                    }
489                }
490            });
491
492            router.add_message_handler({
493                let server = server.clone();
494                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
495                    let server = server.clone();
496                    async move {
497                        let receipt = envelope.receipt();
498                        let message = envelope.payload;
499                        server
500                            .respond(
501                                receipt,
502                                match message.path.as_str() {
503                                    "path/one" => {
504                                        assert_eq!(message.worktree_id, 1);
505                                        proto::OpenBufferResponse {
506                                            buffer: Some(proto::Buffer {
507                                                id: 101,
508                                                content: "path/one content".to_string(),
509                                                history: vec![],
510                                                selections: vec![],
511                                            }),
512                                        }
513                                    }
514                                    "path/two" => {
515                                        assert_eq!(message.worktree_id, 2);
516                                        proto::OpenBufferResponse {
517                                            buffer: Some(proto::Buffer {
518                                                id: 102,
519                                                content: "path/two content".to_string(),
520                                                history: vec![],
521                                                selections: vec![],
522                                            }),
523                                        }
524                                    }
525                                    _ => {
526                                        panic!("unexpected path {}", message.path);
527                                    }
528                                },
529                            )
530                            .await
531                    }
532                }
533            });
534            let router = Arc::new(router);
535
536            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
537            let (client1_conn_id, io_task1, msg_task1) = client1
538                .add_connection(client1_to_server_conn, router.clone())
539                .await;
540            let (_, io_task2, msg_task2) = server
541                .add_connection(server_to_client_1_conn, router.clone())
542                .await;
543
544            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
545            let (client2_conn_id, io_task3, msg_task3) = client2
546                .add_connection(client2_to_server_conn, router.clone())
547                .await;
548            let (_, io_task4, msg_task4) = server
549                .add_connection(server_to_client_2_conn, router.clone())
550                .await;
551
552            smol::spawn(io_task1).detach();
553            smol::spawn(io_task2).detach();
554            smol::spawn(io_task3).detach();
555            smol::spawn(io_task4).detach();
556            smol::spawn(msg_task1).detach();
557            smol::spawn(msg_task2).detach();
558            smol::spawn(msg_task3).detach();
559            smol::spawn(msg_task4).detach();
560
561            assert_eq!(
562                client1
563                    .request(
564                        client1_conn_id,
565                        proto::Auth {
566                            user_id: 1,
567                            access_token: "access-token-1".to_string(),
568                        },
569                    )
570                    .await
571                    .unwrap(),
572                proto::AuthResponse {
573                    credentials_valid: true,
574                }
575            );
576
577            assert_eq!(
578                client2
579                    .request(
580                        client2_conn_id,
581                        proto::Auth {
582                            user_id: 2,
583                            access_token: "access-token-2".to_string(),
584                        },
585                    )
586                    .await
587                    .unwrap(),
588                proto::AuthResponse {
589                    credentials_valid: false,
590                }
591            );
592
593            assert_eq!(
594                client1
595                    .request(
596                        client1_conn_id,
597                        proto::OpenBuffer {
598                            worktree_id: 1,
599                            path: "path/one".to_string(),
600                        },
601                    )
602                    .await
603                    .unwrap(),
604                proto::OpenBufferResponse {
605                    buffer: Some(proto::Buffer {
606                        id: 101,
607                        content: "path/one content".to_string(),
608                        history: vec![],
609                        selections: vec![],
610                    }),
611                }
612            );
613
614            assert_eq!(
615                client2
616                    .request(
617                        client2_conn_id,
618                        proto::OpenBuffer {
619                            worktree_id: 2,
620                            path: "path/two".to_string(),
621                        },
622                    )
623                    .await
624                    .unwrap(),
625                proto::OpenBufferResponse {
626                    buffer: Some(proto::Buffer {
627                        id: 102,
628                        content: "path/two content".to_string(),
629                        history: vec![],
630                        selections: vec![],
631                    }),
632                }
633            );
634
635            client1.disconnect(client1_conn_id).await;
636            client2.disconnect(client1_conn_id).await;
637        });
638    }
639
    /// After `disconnect`, both futures returned by `add_connection` must
    /// finish and the other end of the channel must reject further writes.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            // Run the I/O future and signal on a barrier once it has completed.
            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            // Same for the message-handling future.
            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            // Wait for both futures to wind down before probing the channel.
            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }
675
676    #[test]
677    fn test_io_error() {
678        smol::block_on(async move {
679            let (client_conn, server_conn) = test::Channel::bidirectional();
680            drop(server_conn);
681
682            let client = Peer::new();
683            let router = Arc::new(Router::new());
684            let (connection_id, io_handler, message_handler) =
685                client.add_connection(client_conn, router).await;
686            smol::spawn(io_handler).detach();
687            smol::spawn(message_handler).detach();
688
689            let err = client
690                .request(
691                    connection_id,
692                    proto::Auth {
693                        user_id: 42,
694                        access_token: "token".to_string(),
695                    },
696                )
697                .await
698                .unwrap_err();
699            assert_eq!(err.to_string(), "connection was closed");
700        });
701    }
702}