peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{
  6    future::{BoxFuture, LocalBoxFuture},
  7    FutureExt, StreamExt,
  8};
  9use postage::{
 10    mpsc,
 11    prelude::{Sink, Stream},
 12};
 13use std::{
 14    any::TypeId,
 15    collections::{HashMap, HashSet},
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23};
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(pub u32);
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub struct PeerId(pub u32);
 30
 31type MessageHandler = Box<
 32    dyn Send
 33        + Sync
 34        + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
 35>;
 36
 37type ForegroundMessageHandler =
 38    Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
 39
 40pub struct Receipt<T> {
 41    sender_id: ConnectionId,
 42    message_id: u32,
 43    payload_type: PhantomData<T>,
 44}
 45
 46pub struct TypedEnvelope<T> {
 47    pub sender_id: ConnectionId,
 48    original_sender_id: Option<PeerId>,
 49    pub message_id: u32,
 50    pub payload: T,
 51}
 52
 53impl<T> TypedEnvelope<T> {
 54    pub fn original_sender_id(&self) -> Result<PeerId> {
 55        self.original_sender_id
 56            .ok_or_else(|| anyhow!("missing original_sender_id"))
 57    }
 58}
 59
 60impl<T: RequestMessage> TypedEnvelope<T> {
 61    pub fn receipt(&self) -> Receipt<T> {
 62        Receipt {
 63            sender_id: self.sender_id,
 64            message_id: self.message_id,
 65            payload_type: PhantomData,
 66        }
 67    }
 68}
 69
 70pub type Router = RouterInternal<MessageHandler>;
 71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
 72pub struct RouterInternal<H> {
 73    message_handlers: Vec<H>,
 74    handler_types: HashSet<TypeId>,
 75}
 76
 77pub struct Peer {
 78    connections: RwLock<HashMap<ConnectionId, Connection>>,
 79    next_connection_id: AtomicU32,
 80}
 81
 82#[derive(Clone)]
 83struct Connection {
 84    outgoing_tx: mpsc::Sender<proto::Envelope>,
 85    next_message_id: Arc<AtomicU32>,
 86    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 87}
 88
 89impl Peer {
 90    pub fn new() -> Arc<Self> {
 91        Arc::new(Self {
 92            connections: Default::default(),
 93            next_connection_id: Default::default(),
 94        })
 95    }
 96
 97    pub async fn add_connection<Conn, H, Fut>(
 98        self: &Arc<Self>,
 99        conn: Conn,
100        router: Arc<RouterInternal<H>>,
101    ) -> (
102        ConnectionId,
103        impl Future<Output = anyhow::Result<()>> + Send,
104        impl Future<Output = ()>,
105    )
106    where
107        H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
108        Fut: Future<Output = ()>,
109        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
110            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
111            + Send
112            + Unpin,
113    {
114        let (tx, rx) = conn.split();
115        let connection_id = ConnectionId(
116            self.next_connection_id
117                .fetch_add(1, atomic::Ordering::SeqCst),
118        );
119        let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
120        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
121        let connection = Connection {
122            outgoing_tx,
123            next_message_id: Default::default(),
124            response_channels: Default::default(),
125        };
126        let mut writer = MessageStream::new(tx);
127        let mut reader = MessageStream::new(rx);
128
129        let handle_io = async move {
130            loop {
131                let read_message = reader.read_message().fuse();
132                futures::pin_mut!(read_message);
133                loop {
134                    futures::select_biased! {
135                        incoming = read_message => match incoming {
136                            Ok(incoming) => {
137                                if incoming_tx.send(incoming).await.is_err() {
138                                    return Ok(());
139                                }
140                                break;
141                            }
142                            Err(error) => {
143                                Err(error).context("received invalid RPC message")?;
144                            }
145                        },
146                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
147                            Some(outgoing) => {
148                                if let Err(result) = writer.write_message(&outgoing).await {
149                                    Err(result).context("failed to write RPC message")?;
150                                }
151                            }
152                            None => return Ok(()),
153                        }
154                    }
155                }
156            }
157        };
158
159        let response_channels = connection.response_channels.clone();
160        let handle_messages = async move {
161            while let Some(message) = incoming_rx.recv().await {
162                if let Some(responding_to) = message.responding_to {
163                    let channel = response_channels.lock().await.remove(&responding_to);
164                    if let Some(mut tx) = channel {
165                        tx.send(message).await.ok();
166                    } else {
167                        log::warn!("received RPC response to unknown request {}", responding_to);
168                    }
169                } else {
170                    router.handle(connection_id, message).await;
171                }
172            }
173            response_channels.lock().await.clear();
174        };
175
176        self.connections
177            .write()
178            .await
179            .insert(connection_id, connection);
180
181        (connection_id, handle_io, handle_messages)
182    }
183
184    pub async fn disconnect(&self, connection_id: ConnectionId) {
185        self.connections.write().await.remove(&connection_id);
186    }
187
188    pub async fn reset(&self) {
189        self.connections.write().await.clear();
190    }
191
192    pub fn request<T: RequestMessage>(
193        self: &Arc<Self>,
194        receiver_id: ConnectionId,
195        request: T,
196    ) -> impl Future<Output = Result<T::Response>> {
197        self.request_internal(None, receiver_id, request)
198    }
199
200    pub fn forward_request<T: RequestMessage>(
201        self: &Arc<Self>,
202        sender_id: ConnectionId,
203        receiver_id: ConnectionId,
204        request: T,
205    ) -> impl Future<Output = Result<T::Response>> {
206        self.request_internal(Some(sender_id), receiver_id, request)
207    }
208
209    pub fn request_internal<T: RequestMessage>(
210        self: &Arc<Self>,
211        original_sender_id: Option<ConnectionId>,
212        receiver_id: ConnectionId,
213        request: T,
214    ) -> impl Future<Output = Result<T::Response>> {
215        let this = self.clone();
216        let (tx, mut rx) = mpsc::channel(1);
217        async move {
218            let mut connection = this.connection(receiver_id).await?;
219            let message_id = connection
220                .next_message_id
221                .fetch_add(1, atomic::Ordering::SeqCst);
222            connection
223                .response_channels
224                .lock()
225                .await
226                .insert(message_id, tx);
227            connection
228                .outgoing_tx
229                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
230                .await
231                .map_err(|_| anyhow!("connection was closed"))?;
232            let response = rx
233                .recv()
234                .await
235                .ok_or_else(|| anyhow!("connection was closed"))?;
236            T::Response::from_envelope(response)
237                .ok_or_else(|| anyhow!("received response of the wrong type"))
238        }
239    }
240
241    pub fn send<T: EnvelopedMessage>(
242        self: &Arc<Self>,
243        receiver_id: ConnectionId,
244        message: T,
245    ) -> impl Future<Output = Result<()>> {
246        let this = self.clone();
247        async move {
248            let mut connection = this.connection(receiver_id).await?;
249            let message_id = connection
250                .next_message_id
251                .fetch_add(1, atomic::Ordering::SeqCst);
252            connection
253                .outgoing_tx
254                .send(message.into_envelope(message_id, None, None))
255                .await?;
256            Ok(())
257        }
258    }
259
260    pub fn forward_send<T: EnvelopedMessage>(
261        self: &Arc<Self>,
262        sender_id: ConnectionId,
263        receiver_id: ConnectionId,
264        message: T,
265    ) -> impl Future<Output = Result<()>> {
266        let this = self.clone();
267        async move {
268            let mut connection = this.connection(receiver_id).await?;
269            let message_id = connection
270                .next_message_id
271                .fetch_add(1, atomic::Ordering::SeqCst);
272            connection
273                .outgoing_tx
274                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
275                .await?;
276            Ok(())
277        }
278    }
279
280    pub fn respond<T: RequestMessage>(
281        self: &Arc<Self>,
282        receipt: Receipt<T>,
283        response: T::Response,
284    ) -> impl Future<Output = Result<()>> {
285        let this = self.clone();
286        async move {
287            let mut connection = this.connection(receipt.sender_id).await?;
288            let message_id = connection
289                .next_message_id
290                .fetch_add(1, atomic::Ordering::SeqCst);
291            connection
292                .outgoing_tx
293                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
294                .await?;
295            Ok(())
296        }
297    }
298
299    fn connection(
300        self: &Arc<Self>,
301        connection_id: ConnectionId,
302    ) -> impl Future<Output = Result<Connection>> {
303        let this = self.clone();
304        async move {
305            let connections = this.connections.read().await;
306            let connection = connections
307                .get(&connection_id)
308                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
309            Ok(connection.clone())
310        }
311    }
312}
313
314impl<H, Fut> RouterInternal<H>
315where
316    H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
317    Fut: Future<Output = ()>,
318{
319    pub fn new() -> Self {
320        Self {
321            message_handlers: Default::default(),
322            handler_types: Default::default(),
323        }
324    }
325
326    async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
327        let mut envelope = Some(message);
328        for handler in self.message_handlers.iter() {
329            if let Some(future) = handler(&mut envelope, connection_id) {
330                future.await;
331                return;
332            }
333        }
334        log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
335    }
336}
337
338impl Router {
339    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
340    where
341        T: EnvelopedMessage,
342        Fut: 'static + Send + Future<Output = Result<()>>,
343        F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
344    {
345        if !self.handler_types.insert(TypeId::of::<T>()) {
346            panic!("duplicate handler type");
347        }
348
349        self.message_handlers
350            .push(Box::new(move |envelope, connection_id| {
351                if envelope.as_ref().map_or(false, T::matches_envelope) {
352                    let envelope = Option::take(envelope).unwrap();
353                    let message_id = envelope.id;
354                    let future = handler(TypedEnvelope {
355                        sender_id: connection_id,
356                        original_sender_id: envelope.original_sender_id.map(PeerId),
357                        message_id,
358                        payload: T::from_envelope(envelope).unwrap(),
359                    });
360                    Some(
361                        async move {
362                            if let Err(error) = future.await {
363                                log::error!(
364                                    "error handling message {} {}: {:?}",
365                                    T::NAME,
366                                    message_id,
367                                    error
368                                );
369                            }
370                        }
371                        .boxed(),
372                    )
373                } else {
374                    None
375                }
376            }));
377    }
378}
379
380impl ForegroundRouter {
381    pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
382    where
383        T: EnvelopedMessage,
384        Fut: 'static + Future<Output = Result<()>>,
385        F: 'static + Fn(TypedEnvelope<T>) -> Fut,
386    {
387        if !self.handler_types.insert(TypeId::of::<T>()) {
388            panic!("duplicate handler type");
389        }
390
391        self.message_handlers
392            .push(Box::new(move |envelope, connection_id| {
393                if envelope.as_ref().map_or(false, T::matches_envelope) {
394                    let envelope = Option::take(envelope).unwrap();
395                    let message_id = envelope.id;
396                    let future = handler(TypedEnvelope {
397                        sender_id: connection_id,
398                        original_sender_id: envelope.original_sender_id.map(PeerId),
399                        message_id,
400                        payload: T::from_envelope(envelope).unwrap(),
401                    });
402                    Some(
403                        async move {
404                            if let Err(error) = future.await {
405                                log::error!(
406                                    "error handling message {} {}: {:?}",
407                                    T::NAME,
408                                    message_id,
409                                    error
410                                );
411                            }
412                        }
413                        .boxed_local(),
414                    )
415                } else {
416                    None
417                }
418            }));
419    }
420}
421
422impl<T> Clone for Receipt<T> {
423    fn clone(&self) -> Self {
424        Self {
425            sender_id: self.sender_id,
426            message_id: self.message_id,
427            payload_type: PhantomData,
428        }
429    }
430}
431
432impl<T> Copy for Receipt<T> {}
433
434impl fmt::Display for ConnectionId {
435    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436        self.0.fmt(f)
437    }
438}
439
440impl fmt::Display for PeerId {
441    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442        self.0.fmt(f)
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Two clients share a router with one server; each request must come back with
    /// the response the server's handler chose for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects using the id of its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;
        });
    }

    /// Disconnecting must end both connection futures and close the underlying
    /// channel, so the far end can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request over a connection whose remote end is already gone must fail with
    /// "connection was closed" rather than hang.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}