peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
  5use parking_lot::{Mutex, RwLock};
  6use postage::{
  7    barrier, mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
 42pub struct Receipt<T> {
 43    pub sender_id: ConnectionId,
 44    pub message_id: u32,
 45    payload_type: PhantomData<T>,
 46}
 47
 48impl<T> Clone for Receipt<T> {
 49    fn clone(&self) -> Self {
 50        Self {
 51            sender_id: self.sender_id,
 52            message_id: self.message_id,
 53            payload_type: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<T> Copy for Receipt<T> {}
 59
 60pub struct TypedEnvelope<T> {
 61    pub sender_id: ConnectionId,
 62    pub original_sender_id: Option<PeerId>,
 63    pub message_id: u32,
 64    pub payload: T,
 65}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
 84pub struct Peer {
 85    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 86    next_connection_id: AtomicU32,
 87}
 88
 89#[derive(Clone)]
 90pub struct ConnectionState {
 91    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
 92    next_message_id: Arc<AtomicU32>,
 93    response_channels:
 94        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 95}
 96
 97const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
 98const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
 99
100impl Peer {
101    pub fn new() -> Arc<Self> {
102        Arc::new(Self {
103            connections: Default::default(),
104            next_connection_id: Default::default(),
105        })
106    }
107
108    pub async fn add_connection<F, Fut, Out>(
109        self: &Arc<Self>,
110        connection: Connection,
111        create_timer: F,
112    ) -> (
113        ConnectionId,
114        impl Future<Output = anyhow::Result<()>> + Send,
115        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
116    )
117    where
118        F: Send + Fn(Duration) -> Fut,
119        Fut: Send + Future<Output = Out>,
120        Out: Send,
121    {
122        // For outgoing messages, use an unbounded channel so that application code
123        // can always send messages without yielding. For incoming messages, use a
124        // bounded channel so that other peers will receive backpressure if they send
125        // messages faster than this peer can process them.
126        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
127        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
128
129        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
130        let connection_state = ConnectionState {
131            outgoing_tx: outgoing_tx.clone(),
132            next_message_id: Default::default(),
133            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
134        };
135        let mut writer = MessageStream::new(connection.tx);
136        let mut reader = MessageStream::new(connection.rx);
137
138        let this = self.clone();
139        let response_channels = connection_state.response_channels.clone();
140        let handle_io = async move {
141            let _end_connection = util::defer(|| {
142                response_channels.lock().take();
143                this.connections.write().remove(&connection_id);
144            });
145
146            // Send messages on this frequency so the connection isn't closed.
147            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
148            futures::pin_mut!(keepalive_timer);
149
150            loop {
151                let read_message = reader.read().fuse();
152                futures::pin_mut!(read_message);
153
154                // Disconnect if we don't receive messages at least this frequently.
155                let receive_timeout = create_timer(3 * KEEPALIVE_INTERVAL).fuse();
156                futures::pin_mut!(receive_timeout);
157
158                loop {
159                    futures::select_biased! {
160                        outgoing = outgoing_rx.next().fuse() => match outgoing {
161                            Some(outgoing) => {
162                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
163                                    result.context("failed to write RPC message")?;
164                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
165                                } else {
166                                    Err(anyhow!("timed out writing message"))?;
167                                }
168                            }
169                            None => return Ok(()),
170                        },
171                        incoming = read_message => {
172                            let incoming = incoming.context("received invalid RPC message")?;
173                            if let proto::Message::Envelope(incoming) = incoming {
174                                if incoming_tx.send(incoming).await.is_err() {
175                                    return Ok(());
176                                }
177                            }
178                            break;
179                        },
180                        _ = keepalive_timer => {
181                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
182                                result.context("failed to send keepalive")?;
183                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
184                            } else {
185                                Err(anyhow!("timed out sending keepalive"))?;
186                            }
187                        }
188                        _ = receive_timeout => {
189                            Err(anyhow!("delay between messages too long"))?
190                        }
191                    }
192                }
193            }
194        };
195
196        let response_channels = connection_state.response_channels.clone();
197        self.connections
198            .write()
199            .insert(connection_id, connection_state);
200
201        let incoming_rx = incoming_rx.filter_map(move |incoming| {
202            let response_channels = response_channels.clone();
203            async move {
204                if let Some(responding_to) = incoming.responding_to {
205                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
206                    if let Some(tx) = channel {
207                        let mut requester_resumed = barrier::channel();
208                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
209                            log::debug!(
210                                "received RPC but request future was dropped {:?}",
211                                error.0
212                            );
213                        }
214                        requester_resumed.1.recv().await;
215                    } else {
216                        log::warn!("received RPC response to unknown request {}", responding_to);
217                    }
218
219                    None
220                } else {
221                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
222                        log::error!("unable to construct a typed envelope");
223                        None
224                    })
225                }
226            }
227        });
228        (connection_id, handle_io, incoming_rx.boxed())
229    }
230
231    #[cfg(any(test, feature = "test-support"))]
232    pub async fn add_test_connection(
233        self: &Arc<Self>,
234        connection: Connection,
235        executor: Arc<gpui::executor::Background>,
236    ) -> (
237        ConnectionId,
238        impl Future<Output = anyhow::Result<()>> + Send,
239        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
240    ) {
241        let executor = executor.clone();
242        self.add_connection(connection, move |duration| executor.timer(duration))
243            .await
244    }
245
246    pub fn disconnect(&self, connection_id: ConnectionId) {
247        self.connections.write().remove(&connection_id);
248    }
249
250    pub fn reset(&self) {
251        self.connections.write().clear();
252    }
253
254    pub fn request<T: RequestMessage>(
255        &self,
256        receiver_id: ConnectionId,
257        request: T,
258    ) -> impl Future<Output = Result<T::Response>> {
259        self.request_internal(None, receiver_id, request)
260    }
261
262    pub fn forward_request<T: RequestMessage>(
263        &self,
264        sender_id: ConnectionId,
265        receiver_id: ConnectionId,
266        request: T,
267    ) -> impl Future<Output = Result<T::Response>> {
268        self.request_internal(Some(sender_id), receiver_id, request)
269    }
270
271    pub fn request_internal<T: RequestMessage>(
272        &self,
273        original_sender_id: Option<ConnectionId>,
274        receiver_id: ConnectionId,
275        request: T,
276    ) -> impl Future<Output = Result<T::Response>> {
277        let (tx, rx) = oneshot::channel();
278        let send = self.connection_state(receiver_id).and_then(|connection| {
279            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
280            connection
281                .response_channels
282                .lock()
283                .as_mut()
284                .ok_or_else(|| anyhow!("connection was closed"))?
285                .insert(message_id, tx);
286            connection
287                .outgoing_tx
288                .unbounded_send(proto::Message::Envelope(request.into_envelope(
289                    message_id,
290                    None,
291                    original_sender_id.map(|id| id.0),
292                )))
293                .map_err(|_| anyhow!("connection was closed"))?;
294            Ok(())
295        });
296        async move {
297            send?;
298            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
299            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
300                Err(anyhow!("RPC request failed - {}", error.message))
301            } else {
302                T::Response::from_envelope(response)
303                    .ok_or_else(|| anyhow!("received response of the wrong type"))
304            }
305        }
306    }
307
308    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
309        let connection = self.connection_state(receiver_id)?;
310        let message_id = connection
311            .next_message_id
312            .fetch_add(1, atomic::Ordering::SeqCst);
313        connection
314            .outgoing_tx
315            .unbounded_send(proto::Message::Envelope(
316                message.into_envelope(message_id, None, None),
317            ))?;
318        Ok(())
319    }
320
321    pub fn forward_send<T: EnvelopedMessage>(
322        &self,
323        sender_id: ConnectionId,
324        receiver_id: ConnectionId,
325        message: T,
326    ) -> Result<()> {
327        let connection = self.connection_state(receiver_id)?;
328        let message_id = connection
329            .next_message_id
330            .fetch_add(1, atomic::Ordering::SeqCst);
331        connection
332            .outgoing_tx
333            .unbounded_send(proto::Message::Envelope(message.into_envelope(
334                message_id,
335                None,
336                Some(sender_id.0),
337            )))?;
338        Ok(())
339    }
340
341    pub fn respond<T: RequestMessage>(
342        &self,
343        receipt: Receipt<T>,
344        response: T::Response,
345    ) -> Result<()> {
346        let connection = self.connection_state(receipt.sender_id)?;
347        let message_id = connection
348            .next_message_id
349            .fetch_add(1, atomic::Ordering::SeqCst);
350        connection
351            .outgoing_tx
352            .unbounded_send(proto::Message::Envelope(response.into_envelope(
353                message_id,
354                Some(receipt.message_id),
355                None,
356            )))?;
357        Ok(())
358    }
359
360    pub fn respond_with_error<T: RequestMessage>(
361        &self,
362        receipt: Receipt<T>,
363        response: proto::Error,
364    ) -> Result<()> {
365        let connection = self.connection_state(receipt.sender_id)?;
366        let message_id = connection
367            .next_message_id
368            .fetch_add(1, atomic::Ordering::SeqCst);
369        connection
370            .outgoing_tx
371            .unbounded_send(proto::Message::Envelope(response.into_envelope(
372                message_id,
373                Some(receipt.message_id),
374                None,
375            )))?;
376        Ok(())
377    }
378
379    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
380        let connections = self.connections.read();
381        let connection = connections
382            .get(&connection_id)
383            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
384        Ok(connection.clone())
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients each send requests to one server. The server acknowledges
    /// `Ping` and echoes `Test` payloads.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test`; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages the server sends before its response must be observed by the
    /// client before the response resolves the request.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not stall delivery of later
    /// messages or of other requests' responses.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` ends the I/O future and the incoming stream, and closes the
    /// underlying transport.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A transport failure fails in-flight requests with "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}