// peer.rs

use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{
    future::BoxFuture,
    stream::{SplitSink, SplitStream},
    FutureExt, StreamExt,
};
use postage::{
    mpsc,
    prelude::{Sink, Stream},
};
use std::{
    any::TypeId,
    collections::{HashMap, HashSet},
    fmt,
    future::Future,
    marker::PhantomData,
    sync::{
        atomic::{self, AtomicU32},
        Arc,
    },
};

 /// Identifies one connection registered on a [`Peer`] via `add_connection`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ConnectionId(pub u32);

/// Identifies the peer on whose behalf a forwarded message was originally sent.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct PeerId(pub u32);

 32type MessageHandler = Box<
 33    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 34>;
 35
 36pub struct Receipt<T> {
 37    sender_id: ConnectionId,
 38    message_id: u32,
 39    payload_type: PhantomData<T>,
 40}
 41
 42pub struct TypedEnvelope<T> {
 43    pub sender_id: ConnectionId,
 44    original_sender_id: Option<PeerId>,
 45    pub message_id: u32,
 46    pub payload: T,
 47}
 48
 49impl<T> TypedEnvelope<T> {
 50    pub fn original_sender_id(&self) -> Result<PeerId> {
 51        self.original_sender_id
 52            .ok_or_else(|| anyhow!("missing original_sender_id"))
 53    }
 54}
 55
 56impl<T: RequestMessage> TypedEnvelope<T> {
 57    pub fn receipt(&self) -> Receipt<T> {
 58        Receipt {
 59            sender_id: self.sender_id,
 60            message_id: self.message_id,
 61            payload_type: PhantomData,
 62        }
 63    }
 64}
 65
 66pub struct Peer {
 67    connections: RwLock<HashMap<ConnectionId, Connection>>,
 68    message_handlers: RwLock<Vec<MessageHandler>>,
 69    handler_types: Mutex<HashSet<TypeId>>,
 70    next_connection_id: AtomicU32,
 71}
 72
 73#[derive(Clone)]
 74struct Connection {
 75    outgoing_tx: mpsc::Sender<proto::Envelope>,
 76    next_message_id: Arc<AtomicU32>,
 77    response_channels: ResponseChannels,
 78}
 79
 80pub struct ConnectionHandler<W, R> {
 81    peer: Arc<Peer>,
 82    connection_id: ConnectionId,
 83    response_channels: ResponseChannels,
 84    outgoing_rx: mpsc::Receiver<proto::Envelope>,
 85    writer: MessageStream<W>,
 86    reader: MessageStream<R>,
 87}
 88
 89type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
 90
 91impl Peer {
 92    pub fn new() -> Arc<Self> {
 93        Arc::new(Self {
 94            connections: Default::default(),
 95            message_handlers: Default::default(),
 96            handler_types: Default::default(),
 97            next_connection_id: Default::default(),
 98        })
 99    }
100
101    pub async fn add_message_handler<T: EnvelopedMessage>(
102        &self,
103    ) -> mpsc::Receiver<TypedEnvelope<T>> {
104        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
105            panic!("duplicate handler type");
106        }
107
108        let (tx, rx) = mpsc::channel(256);
109        self.message_handlers
110            .write()
111            .await
112            .push(Box::new(move |envelope, connection_id| {
113                if envelope.as_ref().map_or(false, T::matches_envelope) {
114                    let envelope = Option::take(envelope).unwrap();
115                    let mut tx = tx.clone();
116                    Some(
117                        async move {
118                            tx.send(TypedEnvelope {
119                                sender_id: connection_id,
120                                original_sender_id: envelope.original_sender_id.map(PeerId),
121                                message_id: envelope.id,
122                                payload: T::from_envelope(envelope).unwrap(),
123                            })
124                            .await
125                            .is_err()
126                        }
127                        .boxed(),
128                    )
129                } else {
130                    None
131                }
132            }));
133        rx
134    }
135
136    pub async fn add_connection<Conn>(
137        self: &Arc<Self>,
138        conn: Conn,
139    ) -> (
140        ConnectionId,
141        ConnectionHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
142    )
143    where
144        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
145            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
146            + Unpin,
147    {
148        let (tx, rx) = conn.split();
149        let connection_id = ConnectionId(
150            self.next_connection_id
151                .fetch_add(1, atomic::Ordering::SeqCst),
152        );
153        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
154        let connection = Connection {
155            outgoing_tx,
156            next_message_id: Default::default(),
157            response_channels: Default::default(),
158        };
159        let handler = ConnectionHandler {
160            peer: self.clone(),
161            connection_id,
162            response_channels: connection.response_channels.clone(),
163            outgoing_rx,
164            writer: MessageStream::new(tx),
165            reader: MessageStream::new(rx),
166        };
167        self.connections
168            .write()
169            .await
170            .insert(connection_id, connection);
171        (connection_id, handler)
172    }
173
174    pub async fn disconnect(&self, connection_id: ConnectionId) {
175        self.connections.write().await.remove(&connection_id);
176    }
177
178    pub async fn reset(&self) {
179        self.connections.write().await.clear();
180        self.handler_types.lock().await.clear();
181        self.message_handlers.write().await.clear();
182    }
183
184    pub fn request<T: RequestMessage>(
185        self: &Arc<Self>,
186        receiver_id: ConnectionId,
187        request: T,
188    ) -> impl Future<Output = Result<T::Response>> {
189        self.request_internal(None, receiver_id, request)
190    }
191
192    pub fn forward_request<T: RequestMessage>(
193        self: &Arc<Self>,
194        sender_id: ConnectionId,
195        receiver_id: ConnectionId,
196        request: T,
197    ) -> impl Future<Output = Result<T::Response>> {
198        self.request_internal(Some(sender_id), receiver_id, request)
199    }
200
201    pub fn request_internal<T: RequestMessage>(
202        self: &Arc<Self>,
203        original_sender_id: Option<ConnectionId>,
204        receiver_id: ConnectionId,
205        request: T,
206    ) -> impl Future<Output = Result<T::Response>> {
207        let this = self.clone();
208        let (tx, mut rx) = mpsc::channel(1);
209        async move {
210            let mut connection = this.connection(receiver_id).await?;
211            let message_id = connection
212                .next_message_id
213                .fetch_add(1, atomic::Ordering::SeqCst);
214            connection
215                .response_channels
216                .lock()
217                .await
218                .insert(message_id, tx);
219            connection
220                .outgoing_tx
221                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
222                .await?;
223            let response = rx
224                .recv()
225                .await
226                .ok_or_else(|| anyhow!("connection was closed"))?;
227            T::Response::from_envelope(response)
228                .ok_or_else(|| anyhow!("received response of the wrong type"))
229        }
230    }
231
232    pub fn send<T: EnvelopedMessage>(
233        self: &Arc<Self>,
234        receiver_id: ConnectionId,
235        message: T,
236    ) -> impl Future<Output = Result<()>> {
237        let this = self.clone();
238        async move {
239            let mut connection = this.connection(receiver_id).await?;
240            let message_id = connection
241                .next_message_id
242                .fetch_add(1, atomic::Ordering::SeqCst);
243            connection
244                .outgoing_tx
245                .send(message.into_envelope(message_id, None, None))
246                .await?;
247            Ok(())
248        }
249    }
250
251    pub fn forward_send<T: EnvelopedMessage>(
252        self: &Arc<Self>,
253        sender_id: ConnectionId,
254        receiver_id: ConnectionId,
255        message: T,
256    ) -> impl Future<Output = Result<()>> {
257        let this = self.clone();
258        async move {
259            let mut connection = this.connection(receiver_id).await?;
260            let message_id = connection
261                .next_message_id
262                .fetch_add(1, atomic::Ordering::SeqCst);
263            connection
264                .outgoing_tx
265                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
266                .await?;
267            Ok(())
268        }
269    }
270
271    pub fn respond<T: RequestMessage>(
272        self: &Arc<Self>,
273        receipt: Receipt<T>,
274        response: T::Response,
275    ) -> impl Future<Output = Result<()>> {
276        let this = self.clone();
277        async move {
278            let mut connection = this.connection(receipt.sender_id).await?;
279            let message_id = connection
280                .next_message_id
281                .fetch_add(1, atomic::Ordering::SeqCst);
282            connection
283                .outgoing_tx
284                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
285                .await?;
286            Ok(())
287        }
288    }
289
290    fn connection(
291        self: &Arc<Self>,
292        connection_id: ConnectionId,
293    ) -> impl Future<Output = Result<Connection>> {
294        let this = self.clone();
295        async move {
296            let connections = this.connections.read().await;
297            let connection = connections
298                .get(&connection_id)
299                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
300            Ok(connection.clone())
301        }
302    }
303}
304
305impl<W, R> ConnectionHandler<W, R>
306where
307    W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
308    R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
309{
310    pub async fn run(mut self) -> Result<()> {
311        loop {
312            let read_message = self.reader.read_message().fuse();
313            futures::pin_mut!(read_message);
314            loop {
315                futures::select_biased! {
316                    incoming = read_message => match incoming {
317                        Ok(incoming) => {
318                            Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
319                            break;
320                        }
321                        Err(error) => {
322                            self.response_channels.lock().await.clear();
323                            Err(error).context("received invalid RPC message")?;
324                        }
325                    },
326                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
327                        Some(outgoing) => {
328                            if let Err(result) = self.writer.write_message(&outgoing).await {
329                                self.response_channels.lock().await.clear();
330                                Err(result).context("failed to write RPC message")?;
331                            }
332                        }
333                        None => return Ok(()),
334                    }
335                }
336            }
337        }
338    }
339
340    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
341        let envelope = self.reader.read_message().await?;
342        let original_sender_id = envelope.original_sender_id;
343        let message_id = envelope.id;
344        let payload =
345            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
346        Ok(TypedEnvelope {
347            sender_id: self.connection_id,
348            original_sender_id: original_sender_id.map(PeerId),
349            message_id,
350            payload,
351        })
352    }
353
354    async fn handle_incoming_message(
355        message: proto::Envelope,
356        peer: &Arc<Peer>,
357        connection_id: ConnectionId,
358        response_channels: &ResponseChannels,
359    ) {
360        if let Some(responding_to) = message.responding_to {
361            let channel = response_channels.lock().await.remove(&responding_to);
362            if let Some(mut tx) = channel {
363                tx.send(message).await.ok();
364            } else {
365                log::warn!("received RPC response to unknown request {}", responding_to);
366            }
367        } else {
368            let mut envelope = Some(message);
369            let mut handler_index = None;
370            let mut handler_was_dropped = false;
371            for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
372                if let Some(future) = handler(&mut envelope, connection_id) {
373                    handler_was_dropped = future.await;
374                    handler_index = Some(i);
375                    break;
376                }
377            }
378
379            if let Some(handler_index) = handler_index {
380                if handler_was_dropped {
381                    drop(peer.message_handlers.write().await.remove(handler_index));
382                }
383            } else {
384                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
385            }
386        }
387    }
388}
389
390impl<T> Clone for Receipt<T> {
391    fn clone(&self) -> Self {
392        Self {
393            sender_id: self.sender_id,
394            message_id: self.message_id,
395            payload_type: PhantomData,
396        }
397    }
398}
399
400impl<T> Copy for Receipt<T> {}
401
402impl fmt::Display for ConnectionId {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        self.0.fmt(f)
405    }
406}
407
408impl fmt::Display for PeerId {
409    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410        self.0.fmt(f)
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use postage::oneshot;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await;
            let (_, task2) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await;
            let (_, task4) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(task1.run()).detach();
            smol::spawn(task2.run()).detach();
            smol::spawn(task3.run()).detach();
            smol::spawn(task4.run()).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The spawned task owns its own copies; the originals are used below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                postage::barrier::channel();
            smol::spawn(async move {
                handler.run().await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler.run()).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}