peer.rs

  1use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  5use futures::{FutureExt, StreamExt};
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use std::{
 11    collections::HashMap,
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
 21#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 22pub struct ConnectionId(pub u32);
 23
 24impl fmt::Display for ConnectionId {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        self.0.fmt(f)
 27    }
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 31pub struct PeerId(pub u32);
 32
 33impl fmt::Display for PeerId {
 34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 35        self.0.fmt(f)
 36    }
 37}
 38
 39pub struct Receipt<T> {
 40    pub sender_id: ConnectionId,
 41    pub message_id: u32,
 42    payload_type: PhantomData<T>,
 43}
 44
 45impl<T> Clone for Receipt<T> {
 46    fn clone(&self) -> Self {
 47        Self {
 48            sender_id: self.sender_id,
 49            message_id: self.message_id,
 50            payload_type: PhantomData,
 51        }
 52    }
 53}
 54
 55impl<T> Copy for Receipt<T> {}
 56
 57pub struct TypedEnvelope<T> {
 58    pub sender_id: ConnectionId,
 59    pub original_sender_id: Option<PeerId>,
 60    pub message_id: u32,
 61    pub payload: T,
 62}
 63
 64impl<T> TypedEnvelope<T> {
 65    pub fn original_sender_id(&self) -> Result<PeerId> {
 66        self.original_sender_id
 67            .ok_or_else(|| anyhow!("missing original_sender_id"))
 68    }
 69}
 70
 71impl<T: RequestMessage> TypedEnvelope<T> {
 72    pub fn receipt(&self) -> Receipt<T> {
 73        Receipt {
 74            sender_id: self.sender_id,
 75            message_id: self.message_id,
 76            payload_type: PhantomData,
 77        }
 78    }
 79}
 80
 81pub struct Peer {
 82    connections: RwLock<HashMap<ConnectionId, Connection>>,
 83    next_connection_id: AtomicU32,
 84}
 85
 86#[derive(Clone)]
 87struct Connection {
 88    outgoing_tx: mpsc::Sender<proto::Envelope>,
 89    next_message_id: Arc<AtomicU32>,
 90    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 91}
 92
 93impl Peer {
 94    pub fn new() -> Arc<Self> {
 95        Arc::new(Self {
 96            connections: Default::default(),
 97            next_connection_id: Default::default(),
 98        })
 99    }
100
101    pub async fn add_connection<Conn>(
102        self: &Arc<Self>,
103        conn: Conn,
104    ) -> (
105        ConnectionId,
106        impl Future<Output = anyhow::Result<()>> + Send,
107        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108    )
109    where
110        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
111            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
112            + Send
113            + Unpin,
114    {
115        let (tx, rx) = conn.split();
116        let connection_id = ConnectionId(
117            self.next_connection_id
118                .fetch_add(1, atomic::Ordering::SeqCst),
119        );
120        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
122        let connection = Connection {
123            outgoing_tx,
124            next_message_id: Default::default(),
125            response_channels: Default::default(),
126        };
127        let mut writer = MessageStream::new(tx);
128        let mut reader = MessageStream::new(rx);
129
130        let this = self.clone();
131        let response_channels = connection.response_channels.clone();
132        let handle_io = async move {
133            loop {
134                let read_message = reader.read_message().fuse();
135                futures::pin_mut!(read_message);
136                loop {
137                    futures::select_biased! {
138                        incoming = read_message => match incoming {
139                            Ok(incoming) => {
140                                if let Some(responding_to) = incoming.responding_to {
141                                    let channel = response_channels.lock().await.remove(&responding_to);
142                                    if let Some(mut tx) = channel {
143                                        tx.send(incoming).await.ok();
144                                    } else {
145                                        log::warn!("received RPC response to unknown request {}", responding_to);
146                                    }
147                                } else {
148                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
149                                        if incoming_tx.send(envelope).await.is_err() {
150                                            response_channels.lock().await.clear();
151                                            this.connections.write().await.remove(&connection_id);
152                                            return Ok(())
153                                        }
154                                    } else {
155                                        log::error!("unable to construct a typed envelope");
156                                    }
157                                }
158
159                                break;
160                            }
161                            Err(error) => {
162                                response_channels.lock().await.clear();
163                                this.connections.write().await.remove(&connection_id);
164                                Err(error).context("received invalid RPC message")?;
165                            }
166                        },
167                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
168                            Some(outgoing) => {
169                                if let Err(result) = writer.write_message(&outgoing).await {
170                                    response_channels.lock().await.clear();
171                                    this.connections.write().await.remove(&connection_id);
172                                    Err(result).context("failed to write RPC message")?;
173                                }
174                            }
175                            None => {
176                                response_channels.lock().await.clear();
177                                this.connections.write().await.remove(&connection_id);
178                                return Ok(())
179                            }
180                        }
181                    }
182                }
183            }
184        };
185
186        self.connections
187            .write()
188            .await
189            .insert(connection_id, connection);
190
191        (connection_id, handle_io, incoming_rx)
192    }
193
194    pub async fn disconnect(&self, connection_id: ConnectionId) {
195        self.connections.write().await.remove(&connection_id);
196    }
197
198    pub async fn reset(&self) {
199        self.connections.write().await.clear();
200    }
201
202    pub fn request<T: RequestMessage>(
203        self: &Arc<Self>,
204        receiver_id: ConnectionId,
205        request: T,
206    ) -> impl Future<Output = Result<T::Response>> {
207        self.request_internal(None, receiver_id, request)
208    }
209
210    pub fn forward_request<T: RequestMessage>(
211        self: &Arc<Self>,
212        sender_id: ConnectionId,
213        receiver_id: ConnectionId,
214        request: T,
215    ) -> impl Future<Output = Result<T::Response>> {
216        self.request_internal(Some(sender_id), receiver_id, request)
217    }
218
219    pub fn request_internal<T: RequestMessage>(
220        self: &Arc<Self>,
221        original_sender_id: Option<ConnectionId>,
222        receiver_id: ConnectionId,
223        request: T,
224    ) -> impl Future<Output = Result<T::Response>> {
225        let this = self.clone();
226        let (tx, mut rx) = mpsc::channel(1);
227        async move {
228            let mut connection = this.connection(receiver_id).await?;
229            let message_id = connection
230                .next_message_id
231                .fetch_add(1, atomic::Ordering::SeqCst);
232            connection
233                .response_channels
234                .lock()
235                .await
236                .insert(message_id, tx);
237            connection
238                .outgoing_tx
239                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
240                .await
241                .map_err(|_| anyhow!("connection was closed"))?;
242            let response = rx
243                .recv()
244                .await
245                .ok_or_else(|| anyhow!("connection was closed"))?;
246            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
247                Err(anyhow!("request failed").context(error.message.clone()))
248            } else {
249                T::Response::from_envelope(response)
250                    .ok_or_else(|| anyhow!("received response of the wrong type"))
251            }
252        }
253    }
254
255    pub fn send<T: EnvelopedMessage>(
256        self: &Arc<Self>,
257        receiver_id: ConnectionId,
258        message: T,
259    ) -> impl Future<Output = Result<()>> {
260        let this = self.clone();
261        async move {
262            let mut connection = this.connection(receiver_id).await?;
263            let message_id = connection
264                .next_message_id
265                .fetch_add(1, atomic::Ordering::SeqCst);
266            connection
267                .outgoing_tx
268                .send(message.into_envelope(message_id, None, None))
269                .await?;
270            Ok(())
271        }
272    }
273
274    pub fn forward_send<T: EnvelopedMessage>(
275        self: &Arc<Self>,
276        sender_id: ConnectionId,
277        receiver_id: ConnectionId,
278        message: T,
279    ) -> impl Future<Output = Result<()>> {
280        let this = self.clone();
281        async move {
282            let mut connection = this.connection(receiver_id).await?;
283            let message_id = connection
284                .next_message_id
285                .fetch_add(1, atomic::Ordering::SeqCst);
286            connection
287                .outgoing_tx
288                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
289                .await?;
290            Ok(())
291        }
292    }
293
294    pub fn respond<T: RequestMessage>(
295        self: &Arc<Self>,
296        receipt: Receipt<T>,
297        response: T::Response,
298    ) -> impl Future<Output = Result<()>> {
299        let this = self.clone();
300        async move {
301            let mut connection = this.connection(receipt.sender_id).await?;
302            let message_id = connection
303                .next_message_id
304                .fetch_add(1, atomic::Ordering::SeqCst);
305            connection
306                .outgoing_tx
307                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
308                .await?;
309            Ok(())
310        }
311    }
312
313    pub fn respond_with_error<T: RequestMessage>(
314        self: &Arc<Self>,
315        receipt: Receipt<T>,
316        response: proto::Error,
317    ) -> impl Future<Output = Result<()>> {
318        let this = self.clone();
319        async move {
320            let mut connection = this.connection(receipt.sender_id).await?;
321            let message_id = connection
322                .next_message_id
323                .fetch_add(1, atomic::Ordering::SeqCst);
324            connection
325                .outgoing_tx
326                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
327                .await?;
328            Ok(())
329        }
330    }
331
332    fn connection(
333        self: &Arc<Self>,
334        connection_id: ConnectionId,
335    ) -> impl Future<Output = Result<Connection>> {
336        let this = self.clone();
337        async move {
338            let connections = this.connections.read().await;
339            let connection = connections
340                .get(&connection_id)
341                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
342            Ok(connection.clone())
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    /// Two clients send requests to one server over in-memory channels and get
    /// back the responses produced by the server's message handler.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 })
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 })
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Server-side handler: answers `Ping` with a matching `Pong`, and
            /// `OpenBuffer` with canned contents for the two known paths.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting ends the IO future and the incoming-message stream, and
    /// drops the client end of the transport.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose transport is already gone fails with
    /// "connection was closed" instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(connection_id, proto::Ping { id: 42 })
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}