peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{FutureExt, StreamExt};
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use std::{
 11    any::Any,
 12    collections::HashMap,
 13    fmt,
 14    future::Future,
 15    marker::PhantomData,
 16    sync::{
 17        atomic::{self, AtomicU32},
 18        Arc,
 19    },
 20};
 21
 22#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 23pub struct ConnectionId(pub u32);
 24
 25impl fmt::Display for ConnectionId {
 26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 27        self.0.fmt(f)
 28    }
 29}
 30
 31#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 32pub struct PeerId(pub u32);
 33
 34impl fmt::Display for PeerId {
 35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 36        self.0.fmt(f)
 37    }
 38}
 39
 40pub struct Receipt<T> {
 41    pub sender_id: ConnectionId,
 42    pub message_id: u32,
 43    payload_type: PhantomData<T>,
 44}
 45
 46impl<T> Clone for Receipt<T> {
 47    fn clone(&self) -> Self {
 48        Self {
 49            sender_id: self.sender_id,
 50            message_id: self.message_id,
 51            payload_type: PhantomData,
 52        }
 53    }
 54}
 55
 56impl<T> Copy for Receipt<T> {}
 57
 58pub struct TypedEnvelope<T> {
 59    pub sender_id: ConnectionId,
 60    pub original_sender_id: Option<PeerId>,
 61    pub message_id: u32,
 62    pub payload: T,
 63}
 64
 65impl<T> TypedEnvelope<T> {
 66    pub fn original_sender_id(&self) -> Result<PeerId> {
 67        self.original_sender_id
 68            .ok_or_else(|| anyhow!("missing original_sender_id"))
 69    }
 70}
 71
 72impl<T: RequestMessage> TypedEnvelope<T> {
 73    pub fn receipt(&self) -> Receipt<T> {
 74        Receipt {
 75            sender_id: self.sender_id,
 76            message_id: self.message_id,
 77            payload_type: PhantomData,
 78        }
 79    }
 80}
 81
 82pub struct Peer {
 83    connections: RwLock<HashMap<ConnectionId, Connection>>,
 84    next_connection_id: AtomicU32,
 85}
 86
 87#[derive(Clone)]
 88struct Connection {
 89    outgoing_tx: mpsc::Sender<proto::Envelope>,
 90    next_message_id: Arc<AtomicU32>,
 91    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 92}
 93
 94impl Peer {
 95    pub fn new() -> Arc<Self> {
 96        Arc::new(Self {
 97            connections: Default::default(),
 98            next_connection_id: Default::default(),
 99        })
100    }
101
102    pub async fn add_connection<Conn>(
103        self: &Arc<Self>,
104        conn: Conn,
105    ) -> (
106        ConnectionId,
107        impl Future<Output = anyhow::Result<()>> + Send,
108        mpsc::Receiver<Box<dyn Any + Sync + Send>>,
109    )
110    where
111        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113            + Send
114            + Unpin,
115    {
116        let (tx, rx) = conn.split();
117        let connection_id = ConnectionId(
118            self.next_connection_id
119                .fetch_add(1, atomic::Ordering::SeqCst),
120        );
121        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
122        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123        let connection = Connection {
124            outgoing_tx,
125            next_message_id: Default::default(),
126            response_channels: Default::default(),
127        };
128        let mut writer = MessageStream::new(tx);
129        let mut reader = MessageStream::new(rx);
130
131        let response_channels = connection.response_channels.clone();
132        let handle_io = async move {
133            loop {
134                let read_message = reader.read_message().fuse();
135                futures::pin_mut!(read_message);
136                loop {
137                    futures::select_biased! {
138                        incoming = read_message => match incoming {
139                            Ok(incoming) => {
140                                if let Some(responding_to) = incoming.responding_to {
141                                    let channel = response_channels.lock().await.remove(&responding_to);
142                                    if let Some(mut tx) = channel {
143                                        tx.send(incoming).await.ok();
144                                    } else {
145                                        log::warn!("received RPC response to unknown request {}", responding_to);
146                                    }
147                                } else {
148                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
149                                        if incoming_tx.send(envelope).await.is_err() {
150                                            response_channels.lock().await.clear();
151                                            return Ok(())
152                                        }
153                                    } else {
154                                        log::error!("unable to construct a typed envelope");
155                                    }
156                                }
157
158                                break;
159                            }
160                            Err(error) => {
161                                response_channels.lock().await.clear();
162                                Err(error).context("received invalid RPC message")?;
163                            }
164                        },
165                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
166                            Some(outgoing) => {
167                                if let Err(result) = writer.write_message(&outgoing).await {
168                                    response_channels.lock().await.clear();
169                                    Err(result).context("failed to write RPC message")?;
170                                }
171                            }
172                            None => {
173                                response_channels.lock().await.clear();
174                                return Ok(())
175                            }
176                        }
177                    }
178                }
179            }
180        };
181
182        self.connections
183            .write()
184            .await
185            .insert(connection_id, connection);
186
187        (connection_id, handle_io, incoming_rx)
188    }
189
190    pub async fn disconnect(&self, connection_id: ConnectionId) {
191        self.connections.write().await.remove(&connection_id);
192    }
193
194    pub async fn reset(&self) {
195        self.connections.write().await.clear();
196    }
197
198    pub fn request<T: RequestMessage>(
199        self: &Arc<Self>,
200        receiver_id: ConnectionId,
201        request: T,
202    ) -> impl Future<Output = Result<T::Response>> {
203        self.request_internal(None, receiver_id, request)
204    }
205
206    pub fn forward_request<T: RequestMessage>(
207        self: &Arc<Self>,
208        sender_id: ConnectionId,
209        receiver_id: ConnectionId,
210        request: T,
211    ) -> impl Future<Output = Result<T::Response>> {
212        self.request_internal(Some(sender_id), receiver_id, request)
213    }
214
215    pub fn request_internal<T: RequestMessage>(
216        self: &Arc<Self>,
217        original_sender_id: Option<ConnectionId>,
218        receiver_id: ConnectionId,
219        request: T,
220    ) -> impl Future<Output = Result<T::Response>> {
221        let this = self.clone();
222        let (tx, mut rx) = mpsc::channel(1);
223        async move {
224            let mut connection = this.connection(receiver_id).await?;
225            let message_id = connection
226                .next_message_id
227                .fetch_add(1, atomic::Ordering::SeqCst);
228            connection
229                .response_channels
230                .lock()
231                .await
232                .insert(message_id, tx);
233            connection
234                .outgoing_tx
235                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
236                .await
237                .map_err(|_| anyhow!("connection was closed"))?;
238            let response = rx
239                .recv()
240                .await
241                .ok_or_else(|| anyhow!("connection was closed"))?;
242            T::Response::from_envelope(response)
243                .ok_or_else(|| anyhow!("received response of the wrong type"))
244        }
245    }
246
247    pub fn send<T: EnvelopedMessage>(
248        self: &Arc<Self>,
249        receiver_id: ConnectionId,
250        message: T,
251    ) -> impl Future<Output = Result<()>> {
252        let this = self.clone();
253        async move {
254            let mut connection = this.connection(receiver_id).await?;
255            let message_id = connection
256                .next_message_id
257                .fetch_add(1, atomic::Ordering::SeqCst);
258            connection
259                .outgoing_tx
260                .send(message.into_envelope(message_id, None, None))
261                .await?;
262            Ok(())
263        }
264    }
265
266    pub fn forward_send<T: EnvelopedMessage>(
267        self: &Arc<Self>,
268        sender_id: ConnectionId,
269        receiver_id: ConnectionId,
270        message: T,
271    ) -> impl Future<Output = Result<()>> {
272        let this = self.clone();
273        async move {
274            let mut connection = this.connection(receiver_id).await?;
275            let message_id = connection
276                .next_message_id
277                .fetch_add(1, atomic::Ordering::SeqCst);
278            connection
279                .outgoing_tx
280                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
281                .await?;
282            Ok(())
283        }
284    }
285
286    pub fn respond<T: RequestMessage>(
287        self: &Arc<Self>,
288        receipt: Receipt<T>,
289        response: T::Response,
290    ) -> impl Future<Output = Result<()>> {
291        let this = self.clone();
292        async move {
293            let mut connection = this.connection(receipt.sender_id).await?;
294            let message_id = connection
295                .next_message_id
296                .fetch_add(1, atomic::Ordering::SeqCst);
297            connection
298                .outgoing_tx
299                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
300                .await?;
301            Ok(())
302        }
303    }
304
305    fn connection(
306        self: &Arc<Self>,
307        connection_id: ConnectionId,
308    ) -> impl Future<Output = Result<Connection>> {
309        let this = self.clone();
310        async move {
311            let connections = this.connections.read().await;
312            let connection = connections
313                .get(&connection_id)
314                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
315            Ok(connection.clone())
316        }
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients each issue requests to one server over in-memory channels
    /// and receive the server's responses.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Connection ids are per-peer, so each client must disconnect
            // using the id that it issued.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with a matching `Pong` and
            /// `OpenBuffer` with canned contents for the two known paths.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting ends the I/O future and the incoming-message stream, and
    /// closes the underlying transport so the remote side can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose remote end is already gone fails with
    /// "connection was closed" instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}