peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Conn;
  3use anyhow::{anyhow, Context, Result};
  4use async_lock::{Mutex, RwLock};
  5use futures::FutureExt as _;
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use std::{
 11    collections::HashMap,
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
 21#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 22pub struct ConnectionId(pub u32);
 23
 24impl fmt::Display for ConnectionId {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        self.0.fmt(f)
 27    }
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 31pub struct PeerId(pub u32);
 32
 33impl fmt::Display for PeerId {
 34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 35        self.0.fmt(f)
 36    }
 37}
 38
 39pub struct Receipt<T> {
 40    pub sender_id: ConnectionId,
 41    pub message_id: u32,
 42    payload_type: PhantomData<T>,
 43}
 44
 45impl<T> Clone for Receipt<T> {
 46    fn clone(&self) -> Self {
 47        Self {
 48            sender_id: self.sender_id,
 49            message_id: self.message_id,
 50            payload_type: PhantomData,
 51        }
 52    }
 53}
 54
 55impl<T> Copy for Receipt<T> {}
 56
 57pub struct TypedEnvelope<T> {
 58    pub sender_id: ConnectionId,
 59    pub original_sender_id: Option<PeerId>,
 60    pub message_id: u32,
 61    pub payload: T,
 62}
 63
 64impl<T> TypedEnvelope<T> {
 65    pub fn original_sender_id(&self) -> Result<PeerId> {
 66        self.original_sender_id
 67            .ok_or_else(|| anyhow!("missing original_sender_id"))
 68    }
 69}
 70
 71impl<T: RequestMessage> TypedEnvelope<T> {
 72    pub fn receipt(&self) -> Receipt<T> {
 73        Receipt {
 74            sender_id: self.sender_id,
 75            message_id: self.message_id,
 76            payload_type: PhantomData,
 77        }
 78    }
 79}
 80
 81pub struct Peer {
 82    connections: RwLock<HashMap<ConnectionId, Connection>>,
 83    next_connection_id: AtomicU32,
 84}
 85
 86#[derive(Clone)]
 87struct Connection {
 88    outgoing_tx: mpsc::Sender<proto::Envelope>,
 89    next_message_id: Arc<AtomicU32>,
 90    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 91}
 92
 93impl Peer {
 94    pub fn new() -> Arc<Self> {
 95        Arc::new(Self {
 96            connections: Default::default(),
 97            next_connection_id: Default::default(),
 98        })
 99    }
100
101    pub async fn add_connection(
102        self: &Arc<Self>,
103        conn: Conn,
104    ) -> (
105        ConnectionId,
106        impl Future<Output = anyhow::Result<()>> + Send,
107        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108    ) {
109        let connection_id = ConnectionId(
110            self.next_connection_id
111                .fetch_add(1, atomic::Ordering::SeqCst),
112        );
113        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
114        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
115        let connection = Connection {
116            outgoing_tx,
117            next_message_id: Default::default(),
118            response_channels: Default::default(),
119        };
120        let mut writer = MessageStream::new(conn.tx);
121        let mut reader = MessageStream::new(conn.rx);
122
123        let this = self.clone();
124        let response_channels = connection.response_channels.clone();
125        let handle_io = async move {
126            loop {
127                let read_message = reader.read_message().fuse();
128                futures::pin_mut!(read_message);
129                loop {
130                    futures::select_biased! {
131                        incoming = read_message => match incoming {
132                            Ok(incoming) => {
133                                if let Some(responding_to) = incoming.responding_to {
134                                    let channel = response_channels.lock().await.remove(&responding_to);
135                                    if let Some(mut tx) = channel {
136                                        tx.send(incoming).await.ok();
137                                    } else {
138                                        log::warn!("received RPC response to unknown request {}", responding_to);
139                                    }
140                                } else {
141                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
142                                        if incoming_tx.send(envelope).await.is_err() {
143                                            response_channels.lock().await.clear();
144                                            this.connections.write().await.remove(&connection_id);
145                                            return Ok(())
146                                        }
147                                    } else {
148                                        log::error!("unable to construct a typed envelope");
149                                    }
150                                }
151
152                                break;
153                            }
154                            Err(error) => {
155                                response_channels.lock().await.clear();
156                                this.connections.write().await.remove(&connection_id);
157                                Err(error).context("received invalid RPC message")?;
158                            }
159                        },
160                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
161                            Some(outgoing) => {
162                                if let Err(result) = writer.write_message(&outgoing).await {
163                                    response_channels.lock().await.clear();
164                                    this.connections.write().await.remove(&connection_id);
165                                    Err(result).context("failed to write RPC message")?;
166                                }
167                            }
168                            None => {
169                                response_channels.lock().await.clear();
170                                this.connections.write().await.remove(&connection_id);
171                                return Ok(())
172                            }
173                        }
174                    }
175                }
176            }
177        };
178
179        self.connections
180            .write()
181            .await
182            .insert(connection_id, connection);
183
184        (connection_id, handle_io, incoming_rx)
185    }
186
187    pub async fn disconnect(&self, connection_id: ConnectionId) {
188        self.connections.write().await.remove(&connection_id);
189    }
190
191    pub async fn reset(&self) {
192        self.connections.write().await.clear();
193    }
194
195    pub fn request<T: RequestMessage>(
196        self: &Arc<Self>,
197        receiver_id: ConnectionId,
198        request: T,
199    ) -> impl Future<Output = Result<T::Response>> {
200        self.request_internal(None, receiver_id, request)
201    }
202
203    pub fn forward_request<T: RequestMessage>(
204        self: &Arc<Self>,
205        sender_id: ConnectionId,
206        receiver_id: ConnectionId,
207        request: T,
208    ) -> impl Future<Output = Result<T::Response>> {
209        self.request_internal(Some(sender_id), receiver_id, request)
210    }
211
212    pub fn request_internal<T: RequestMessage>(
213        self: &Arc<Self>,
214        original_sender_id: Option<ConnectionId>,
215        receiver_id: ConnectionId,
216        request: T,
217    ) -> impl Future<Output = Result<T::Response>> {
218        let this = self.clone();
219        let (tx, mut rx) = mpsc::channel(1);
220        async move {
221            let mut connection = this.connection(receiver_id).await?;
222            let message_id = connection
223                .next_message_id
224                .fetch_add(1, atomic::Ordering::SeqCst);
225            connection
226                .response_channels
227                .lock()
228                .await
229                .insert(message_id, tx);
230            connection
231                .outgoing_tx
232                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
233                .await
234                .map_err(|_| anyhow!("connection was closed"))?;
235            let response = rx
236                .recv()
237                .await
238                .ok_or_else(|| anyhow!("connection was closed"))?;
239            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
240                Err(anyhow!("request failed").context(error.message.clone()))
241            } else {
242                T::Response::from_envelope(response)
243                    .ok_or_else(|| anyhow!("received response of the wrong type"))
244            }
245        }
246    }
247
248    pub fn send<T: EnvelopedMessage>(
249        self: &Arc<Self>,
250        receiver_id: ConnectionId,
251        message: T,
252    ) -> impl Future<Output = Result<()>> {
253        let this = self.clone();
254        async move {
255            let mut connection = this.connection(receiver_id).await?;
256            let message_id = connection
257                .next_message_id
258                .fetch_add(1, atomic::Ordering::SeqCst);
259            connection
260                .outgoing_tx
261                .send(message.into_envelope(message_id, None, None))
262                .await?;
263            Ok(())
264        }
265    }
266
267    pub fn forward_send<T: EnvelopedMessage>(
268        self: &Arc<Self>,
269        sender_id: ConnectionId,
270        receiver_id: ConnectionId,
271        message: T,
272    ) -> impl Future<Output = Result<()>> {
273        let this = self.clone();
274        async move {
275            let mut connection = this.connection(receiver_id).await?;
276            let message_id = connection
277                .next_message_id
278                .fetch_add(1, atomic::Ordering::SeqCst);
279            connection
280                .outgoing_tx
281                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
282                .await?;
283            Ok(())
284        }
285    }
286
287    pub fn respond<T: RequestMessage>(
288        self: &Arc<Self>,
289        receipt: Receipt<T>,
290        response: T::Response,
291    ) -> impl Future<Output = Result<()>> {
292        let this = self.clone();
293        async move {
294            let mut connection = this.connection(receipt.sender_id).await?;
295            let message_id = connection
296                .next_message_id
297                .fetch_add(1, atomic::Ordering::SeqCst);
298            connection
299                .outgoing_tx
300                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
301                .await?;
302            Ok(())
303        }
304    }
305
306    pub fn respond_with_error<T: RequestMessage>(
307        self: &Arc<Self>,
308        receipt: Receipt<T>,
309        response: proto::Error,
310    ) -> impl Future<Output = Result<()>> {
311        let this = self.clone();
312        async move {
313            let mut connection = this.connection(receipt.sender_id).await?;
314            let message_id = connection
315                .next_message_id
316                .fetch_add(1, atomic::Ordering::SeqCst);
317            connection
318                .outgoing_tx
319                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
320                .await?;
321            Ok(())
322        }
323    }
324
325    fn connection(
326        self: &Arc<Self>,
327        connection_id: ConnectionId,
328    ) -> impl Future<Output = Result<Connection>> {
329        let this = self.clone();
330        async move {
331            let connections = this.connections.read().await;
332            let connection = connections
333                .get(&connection_id)
334                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
335            Ok(connection.clone())
336        }
337    }
338}
339
340#[cfg(test)]
341mod tests {
342    use super::*;
343    use crate::TypedEnvelope;
344    use async_tungstenite::tungstenite::Message as WebSocketMessage;
345    use futures::StreamExt as _;
346
347    #[test]
348    fn test_request_response() {
349        smol::block_on(async move {
350            // create 2 clients connected to 1 server
351            let server = Peer::new();
352            let client1 = Peer::new();
353            let client2 = Peer::new();
354
355            let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
356            let (client1_conn_id, io_task1, _) =
357                client1.add_connection(client1_to_server_conn).await;
358            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
359
360            let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
361            let (client2_conn_id, io_task3, _) =
362                client2.add_connection(client2_to_server_conn).await;
363            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
364
365            smol::spawn(io_task1).detach();
366            smol::spawn(io_task2).detach();
367            smol::spawn(io_task3).detach();
368            smol::spawn(io_task4).detach();
369            smol::spawn(handle_messages(incoming1, server.clone())).detach();
370            smol::spawn(handle_messages(incoming2, server.clone())).detach();
371
372            assert_eq!(
373                client1
374                    .request(client1_conn_id, proto::Ping { id: 1 },)
375                    .await
376                    .unwrap(),
377                proto::Pong { id: 1 }
378            );
379
380            assert_eq!(
381                client2
382                    .request(client2_conn_id, proto::Ping { id: 2 },)
383                    .await
384                    .unwrap(),
385                proto::Pong { id: 2 }
386            );
387
388            assert_eq!(
389                client1
390                    .request(
391                        client1_conn_id,
392                        proto::OpenBuffer {
393                            worktree_id: 1,
394                            path: "path/one".to_string(),
395                        },
396                    )
397                    .await
398                    .unwrap(),
399                proto::OpenBufferResponse {
400                    buffer: Some(proto::Buffer {
401                        id: 101,
402                        content: "path/one content".to_string(),
403                        history: vec![],
404                        selections: vec![],
405                    }),
406                }
407            );
408
409            assert_eq!(
410                client2
411                    .request(
412                        client2_conn_id,
413                        proto::OpenBuffer {
414                            worktree_id: 2,
415                            path: "path/two".to_string(),
416                        },
417                    )
418                    .await
419                    .unwrap(),
420                proto::OpenBufferResponse {
421                    buffer: Some(proto::Buffer {
422                        id: 102,
423                        content: "path/two content".to_string(),
424                        history: vec![],
425                        selections: vec![],
426                    }),
427                }
428            );
429
430            client1.disconnect(client1_conn_id).await;
431            client2.disconnect(client1_conn_id).await;
432
433            async fn handle_messages(
434                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
435                peer: Arc<Peer>,
436            ) -> Result<()> {
437                while let Some(envelope) = messages.next().await {
438                    let envelope = envelope.into_any();
439                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
440                        let receipt = envelope.receipt();
441                        peer.respond(
442                            receipt,
443                            proto::Pong {
444                                id: envelope.payload.id,
445                            },
446                        )
447                        .await?
448                    } else if let Some(envelope) =
449                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
450                    {
451                        let message = &envelope.payload;
452                        let receipt = envelope.receipt();
453                        let response = match message.path.as_str() {
454                            "path/one" => {
455                                assert_eq!(message.worktree_id, 1);
456                                proto::OpenBufferResponse {
457                                    buffer: Some(proto::Buffer {
458                                        id: 101,
459                                        content: "path/one content".to_string(),
460                                        history: vec![],
461                                        selections: vec![],
462                                    }),
463                                }
464                            }
465                            "path/two" => {
466                                assert_eq!(message.worktree_id, 2);
467                                proto::OpenBufferResponse {
468                                    buffer: Some(proto::Buffer {
469                                        id: 102,
470                                        content: "path/two content".to_string(),
471                                        history: vec![],
472                                        selections: vec![],
473                                    }),
474                                }
475                            }
476                            _ => {
477                                panic!("unexpected path {}", message.path);
478                            }
479                        };
480
481                        peer.respond(receipt, response).await?
482                    } else {
483                        panic!("unknown message type");
484                    }
485                }
486
487                Ok(())
488            }
489        });
490    }
491
492    #[test]
493    fn test_disconnect() {
494        smol::block_on(async move {
495            let (client_conn, mut server_conn, _) = Conn::in_memory();
496
497            let client = Peer::new();
498            let (connection_id, io_handler, mut incoming) =
499                client.add_connection(client_conn).await;
500
501            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
502            smol::spawn(async move {
503                io_handler.await.ok();
504                io_ended_tx.send(()).await.unwrap();
505            })
506            .detach();
507
508            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
509            smol::spawn(async move {
510                incoming.next().await;
511                messages_ended_tx.send(()).await.unwrap();
512            })
513            .detach();
514
515            client.disconnect(connection_id).await;
516
517            io_ended_rx.recv().await;
518            messages_ended_rx.recv().await;
519            assert!(server_conn
520                .send(WebSocketMessage::Binary(vec![]))
521                .await
522                .is_err());
523        });
524    }
525
526    #[test]
527    fn test_io_error() {
528        smol::block_on(async move {
529            let (client_conn, server_conn, _) = Conn::in_memory();
530            drop(server_conn);
531
532            let client = Peer::new();
533            let (connection_id, io_handler, mut incoming) =
534                client.add_connection(client_conn).await;
535            smol::spawn(io_handler).detach();
536            smol::spawn(async move { incoming.next().await }).detach();
537
538            let err = client
539                .request(connection_id, proto::Ping { id: 42 })
540                .await
541                .unwrap_err();
542            assert_eq!(err.to_string(), "connection was closed");
543        });
544    }
545}