peer.rs

  1use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{FutureExt, StreamExt};
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use std::{
 11    collections::HashMap,
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
 21#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 22pub struct ConnectionId(pub u32);
 23
 24impl fmt::Display for ConnectionId {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        self.0.fmt(f)
 27    }
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 31pub struct PeerId(pub u32);
 32
 33impl fmt::Display for PeerId {
 34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 35        self.0.fmt(f)
 36    }
 37}
 38
 39pub struct Receipt<T> {
 40    pub sender_id: ConnectionId,
 41    pub message_id: u32,
 42    payload_type: PhantomData<T>,
 43}
 44
 45impl<T> Clone for Receipt<T> {
 46    fn clone(&self) -> Self {
 47        Self {
 48            sender_id: self.sender_id,
 49            message_id: self.message_id,
 50            payload_type: PhantomData,
 51        }
 52    }
 53}
 54
 55impl<T> Copy for Receipt<T> {}
 56
 57pub struct TypedEnvelope<T> {
 58    pub sender_id: ConnectionId,
 59    pub original_sender_id: Option<PeerId>,
 60    pub message_id: u32,
 61    pub payload: T,
 62}
 63
 64impl<T> TypedEnvelope<T> {
 65    pub fn original_sender_id(&self) -> Result<PeerId> {
 66        self.original_sender_id
 67            .ok_or_else(|| anyhow!("missing original_sender_id"))
 68    }
 69}
 70
 71impl<T: RequestMessage> TypedEnvelope<T> {
 72    pub fn receipt(&self) -> Receipt<T> {
 73        Receipt {
 74            sender_id: self.sender_id,
 75            message_id: self.message_id,
 76            payload_type: PhantomData,
 77        }
 78    }
 79}
 80
 81pub struct Peer {
 82    connections: RwLock<HashMap<ConnectionId, Connection>>,
 83    next_connection_id: AtomicU32,
 84}
 85
 86#[derive(Clone)]
 87struct Connection {
 88    outgoing_tx: mpsc::Sender<proto::Envelope>,
 89    next_message_id: Arc<AtomicU32>,
 90    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 91}
 92
 93impl Peer {
 94    pub fn new() -> Arc<Self> {
 95        Arc::new(Self {
 96            connections: Default::default(),
 97            next_connection_id: Default::default(),
 98        })
 99    }
100
101    pub async fn add_connection<Conn>(
102        self: &Arc<Self>,
103        conn: Conn,
104    ) -> (
105        ConnectionId,
106        impl Future<Output = anyhow::Result<()>> + Send,
107        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108    )
109    where
110        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
111            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
112            + Send
113            + Unpin,
114    {
115        let (tx, rx) = conn.split();
116        let connection_id = ConnectionId(
117            self.next_connection_id
118                .fetch_add(1, atomic::Ordering::SeqCst),
119        );
120        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
122        let connection = Connection {
123            outgoing_tx,
124            next_message_id: Default::default(),
125            response_channels: Default::default(),
126        };
127        let mut writer = MessageStream::new(tx);
128        let mut reader = MessageStream::new(rx);
129
130        let response_channels = connection.response_channels.clone();
131        let handle_io = async move {
132            loop {
133                let read_message = reader.read_message().fuse();
134                futures::pin_mut!(read_message);
135                loop {
136                    futures::select_biased! {
137                        incoming = read_message => match incoming {
138                            Ok(incoming) => {
139                                if let Some(responding_to) = incoming.responding_to {
140                                    let channel = response_channels.lock().await.remove(&responding_to);
141                                    if let Some(mut tx) = channel {
142                                        tx.send(incoming).await.ok();
143                                    } else {
144                                        log::warn!("received RPC response to unknown request {}", responding_to);
145                                    }
146                                } else {
147                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
148                                        if incoming_tx.send(envelope).await.is_err() {
149                                            response_channels.lock().await.clear();
150                                            return Ok(())
151                                        }
152                                    } else {
153                                        log::error!("unable to construct a typed envelope");
154                                    }
155                                }
156
157                                break;
158                            }
159                            Err(error) => {
160                                response_channels.lock().await.clear();
161                                Err(error).context("received invalid RPC message")?;
162                            }
163                        },
164                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
165                            Some(outgoing) => {
166                                if let Err(result) = writer.write_message(&outgoing).await {
167                                    response_channels.lock().await.clear();
168                                    Err(result).context("failed to write RPC message")?;
169                                }
170                            }
171                            None => {
172                                response_channels.lock().await.clear();
173                                return Ok(())
174                            }
175                        }
176                    }
177                }
178            }
179        };
180
181        self.connections
182            .write()
183            .await
184            .insert(connection_id, connection);
185
186        (connection_id, handle_io, incoming_rx)
187    }
188
189    pub async fn disconnect(&self, connection_id: ConnectionId) {
190        self.connections.write().await.remove(&connection_id);
191    }
192
193    pub async fn reset(&self) {
194        self.connections.write().await.clear();
195    }
196
197    pub fn request<T: RequestMessage>(
198        self: &Arc<Self>,
199        receiver_id: ConnectionId,
200        request: T,
201    ) -> impl Future<Output = Result<T::Response>> {
202        self.request_internal(None, receiver_id, request)
203    }
204
205    pub fn forward_request<T: RequestMessage>(
206        self: &Arc<Self>,
207        sender_id: ConnectionId,
208        receiver_id: ConnectionId,
209        request: T,
210    ) -> impl Future<Output = Result<T::Response>> {
211        self.request_internal(Some(sender_id), receiver_id, request)
212    }
213
214    pub fn request_internal<T: RequestMessage>(
215        self: &Arc<Self>,
216        original_sender_id: Option<ConnectionId>,
217        receiver_id: ConnectionId,
218        request: T,
219    ) -> impl Future<Output = Result<T::Response>> {
220        let this = self.clone();
221        let (tx, mut rx) = mpsc::channel(1);
222        async move {
223            let mut connection = this.connection(receiver_id).await?;
224            let message_id = connection
225                .next_message_id
226                .fetch_add(1, atomic::Ordering::SeqCst);
227            connection
228                .response_channels
229                .lock()
230                .await
231                .insert(message_id, tx);
232            connection
233                .outgoing_tx
234                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
235                .await
236                .map_err(|_| anyhow!("connection was closed"))?;
237            let response = rx
238                .recv()
239                .await
240                .ok_or_else(|| anyhow!("connection was closed"))?;
241            T::Response::from_envelope(response)
242                .ok_or_else(|| anyhow!("received response of the wrong type"))
243        }
244    }
245
246    pub fn send<T: EnvelopedMessage>(
247        self: &Arc<Self>,
248        receiver_id: ConnectionId,
249        message: T,
250    ) -> impl Future<Output = Result<()>> {
251        let this = self.clone();
252        async move {
253            let mut connection = this.connection(receiver_id).await?;
254            let message_id = connection
255                .next_message_id
256                .fetch_add(1, atomic::Ordering::SeqCst);
257            connection
258                .outgoing_tx
259                .send(message.into_envelope(message_id, None, None))
260                .await?;
261            Ok(())
262        }
263    }
264
265    pub fn forward_send<T: EnvelopedMessage>(
266        self: &Arc<Self>,
267        sender_id: ConnectionId,
268        receiver_id: ConnectionId,
269        message: T,
270    ) -> impl Future<Output = Result<()>> {
271        let this = self.clone();
272        async move {
273            let mut connection = this.connection(receiver_id).await?;
274            let message_id = connection
275                .next_message_id
276                .fetch_add(1, atomic::Ordering::SeqCst);
277            connection
278                .outgoing_tx
279                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
280                .await?;
281            Ok(())
282        }
283    }
284
285    pub fn respond<T: RequestMessage>(
286        self: &Arc<Self>,
287        receipt: Receipt<T>,
288        response: T::Response,
289    ) -> impl Future<Output = Result<()>> {
290        let this = self.clone();
291        async move {
292            let mut connection = this.connection(receipt.sender_id).await?;
293            let message_id = connection
294                .next_message_id
295                .fetch_add(1, atomic::Ordering::SeqCst);
296            connection
297                .outgoing_tx
298                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
299                .await?;
300            Ok(())
301        }
302    }
303
304    fn connection(
305        self: &Arc<Self>,
306        connection_id: ConnectionId,
307    ) -> impl Future<Output = Result<Connection>> {
308        let this = self.clone();
309        async move {
310            let connections = this.connections.read().await;
311            let connection = connections
312                .get(&connection_id)
313                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
314            Ok(connection.clone())
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients connect to one server. Each request must be answered by the
    /// server's message handler with the expected response.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection. Connection ids are
            // per-peer, so client 2 must use the id it was assigned.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with a `Pong` of the same
            /// id, and answers `OpenBuffer` for the two known paths.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// After `disconnect`, the following must all happen:
    /// - the I/O future finishes;
    /// - the incoming stream ends;
    /// - the remote end can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose remote end is gone fails with
    /// "connection was closed" instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping { id: 42 })
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}