peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
  5use parking_lot::{Mutex, RwLock};
  6use postage::{
  7    barrier, mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
 42pub struct Receipt<T> {
 43    pub sender_id: ConnectionId,
 44    pub message_id: u32,
 45    payload_type: PhantomData<T>,
 46}
 47
 48impl<T> Clone for Receipt<T> {
 49    fn clone(&self) -> Self {
 50        Self {
 51            sender_id: self.sender_id,
 52            message_id: self.message_id,
 53            payload_type: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<T> Copy for Receipt<T> {}
 59
 60pub struct TypedEnvelope<T> {
 61    pub sender_id: ConnectionId,
 62    pub original_sender_id: Option<PeerId>,
 63    pub message_id: u32,
 64    pub payload: T,
 65}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
 84pub struct Peer {
 85    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 86    next_connection_id: AtomicU32,
 87}
 88
 89#[derive(Clone)]
 90pub struct ConnectionState {
 91    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
 92    next_message_id: Arc<AtomicU32>,
 93    response_channels:
 94        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 95}
 96
 97const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
 98
 99impl Peer {
100    pub fn new() -> Arc<Self> {
101        Arc::new(Self {
102            connections: Default::default(),
103            next_connection_id: Default::default(),
104        })
105    }
106
107    pub async fn add_connection<F, Fut, Out>(
108        self: &Arc<Self>,
109        connection: Connection,
110        create_timer: F,
111    ) -> (
112        ConnectionId,
113        impl Future<Output = anyhow::Result<()>> + Send,
114        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115    )
116    where
117        F: Send + Fn(Duration) -> Fut,
118        Fut: Send + Future<Output = Out>,
119        Out: Send,
120    {
121        // For outgoing messages, use an unbounded channel so that application code
122        // can always send messages without yielding. For incoming messages, use a
123        // bounded channel so that other peers will receive backpressure if they send
124        // messages faster than this peer can process them.
125        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
126        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
127
128        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
129        let connection_state = ConnectionState {
130            outgoing_tx: outgoing_tx.clone(),
131            next_message_id: Default::default(),
132            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
133        };
134        let mut writer = MessageStream::new(connection.tx);
135        let mut reader = MessageStream::new(connection.rx);
136
137        let this = self.clone();
138        let response_channels = connection_state.response_channels.clone();
139        let handle_io = async move {
140            let _end_connection = util::defer(|| {
141                response_channels.lock().take();
142                this.connections.write().remove(&connection_id);
143            });
144
145            loop {
146                let read_message = reader.read().fuse();
147                futures::pin_mut!(read_message);
148                let read_timeout = create_timer(2 * KEEPALIVE_INTERVAL).fuse();
149                futures::pin_mut!(read_timeout);
150
151                loop {
152                    futures::select_biased! {
153                        outgoing = outgoing_rx.next().fuse() => match outgoing {
154                            Some(outgoing) => {
155                                let outgoing = proto::Message::Envelope(outgoing);
156                                if let Some(result) = writer.write(outgoing).timeout(2 * KEEPALIVE_INTERVAL).await {
157                                    result.context("failed to write RPC message")?;
158                                } else {
159                                    Err(anyhow!("timed out writing message"))?;
160                                }
161                            }
162                            None => return Ok(()),
163                        },
164                        incoming = read_message => {
165                            let incoming = incoming.context("received invalid RPC message")?;
166                            if let proto::Message::Envelope(incoming) = incoming {
167                                if incoming_tx.send(incoming).await.is_err() {
168                                    return Ok(());
169                                }
170                            }
171
172                            break;
173                        },
174                        _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
175                            if let Some(result) = writer.write(proto::Message::Ping).timeout(2 * KEEPALIVE_INTERVAL).await {
176                                result.context("failed to send websocket ping")?;
177                            } else {
178                                Err(anyhow!("timed out sending websocket ping"))?;
179                            }
180                        }
181                        _ = read_timeout => {
182                            Err(anyhow!("timed out reading message"))?
183                        }
184                    }
185                }
186            }
187        };
188
189        let response_channels = connection_state.response_channels.clone();
190        self.connections
191            .write()
192            .insert(connection_id, connection_state);
193
194        let incoming_rx = incoming_rx.filter_map(move |incoming| {
195            let response_channels = response_channels.clone();
196            async move {
197                if let Some(responding_to) = incoming.responding_to {
198                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
199                    if let Some(tx) = channel {
200                        let mut requester_resumed = barrier::channel();
201                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
202                            log::debug!(
203                                "received RPC but request future was dropped {:?}",
204                                error.0
205                            );
206                        }
207                        requester_resumed.1.recv().await;
208                    } else {
209                        log::warn!("received RPC response to unknown request {}", responding_to);
210                    }
211
212                    None
213                } else {
214                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
215                        log::error!("unable to construct a typed envelope");
216                        None
217                    })
218                }
219            }
220        });
221        (connection_id, handle_io, incoming_rx.boxed())
222    }
223
224    #[cfg(any(test, feature = "test-support"))]
225    pub async fn add_test_connection(
226        self: &Arc<Self>,
227        connection: Connection,
228        executor: Arc<gpui::executor::Background>,
229    ) -> (
230        ConnectionId,
231        impl Future<Output = anyhow::Result<()>> + Send,
232        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
233    ) {
234        let executor = executor.clone();
235        self.add_connection(connection, move |duration| executor.timer(duration))
236            .await
237    }
238
239    pub fn disconnect(&self, connection_id: ConnectionId) {
240        self.connections.write().remove(&connection_id);
241    }
242
243    pub fn reset(&self) {
244        self.connections.write().clear();
245    }
246
247    pub fn request<T: RequestMessage>(
248        &self,
249        receiver_id: ConnectionId,
250        request: T,
251    ) -> impl Future<Output = Result<T::Response>> {
252        self.request_internal(None, receiver_id, request)
253    }
254
255    pub fn forward_request<T: RequestMessage>(
256        &self,
257        sender_id: ConnectionId,
258        receiver_id: ConnectionId,
259        request: T,
260    ) -> impl Future<Output = Result<T::Response>> {
261        self.request_internal(Some(sender_id), receiver_id, request)
262    }
263
264    pub fn request_internal<T: RequestMessage>(
265        &self,
266        original_sender_id: Option<ConnectionId>,
267        receiver_id: ConnectionId,
268        request: T,
269    ) -> impl Future<Output = Result<T::Response>> {
270        let (tx, rx) = oneshot::channel();
271        let send = self.connection_state(receiver_id).and_then(|connection| {
272            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
273            connection
274                .response_channels
275                .lock()
276                .as_mut()
277                .ok_or_else(|| anyhow!("connection was closed"))?
278                .insert(message_id, tx);
279            connection
280                .outgoing_tx
281                .unbounded_send(request.into_envelope(
282                    message_id,
283                    None,
284                    original_sender_id.map(|id| id.0),
285                ))
286                .map_err(|_| anyhow!("connection was closed"))?;
287            Ok(())
288        });
289        async move {
290            send?;
291            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
292            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
293                Err(anyhow!("RPC request failed - {}", error.message))
294            } else {
295                T::Response::from_envelope(response)
296                    .ok_or_else(|| anyhow!("received response of the wrong type"))
297            }
298        }
299    }
300
301    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
302        let connection = self.connection_state(receiver_id)?;
303        let message_id = connection
304            .next_message_id
305            .fetch_add(1, atomic::Ordering::SeqCst);
306        connection
307            .outgoing_tx
308            .unbounded_send(message.into_envelope(message_id, None, None))?;
309        Ok(())
310    }
311
312    pub fn forward_send<T: EnvelopedMessage>(
313        &self,
314        sender_id: ConnectionId,
315        receiver_id: ConnectionId,
316        message: T,
317    ) -> Result<()> {
318        let connection = self.connection_state(receiver_id)?;
319        let message_id = connection
320            .next_message_id
321            .fetch_add(1, atomic::Ordering::SeqCst);
322        connection
323            .outgoing_tx
324            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
325        Ok(())
326    }
327
328    pub fn respond<T: RequestMessage>(
329        &self,
330        receipt: Receipt<T>,
331        response: T::Response,
332    ) -> Result<()> {
333        let connection = self.connection_state(receipt.sender_id)?;
334        let message_id = connection
335            .next_message_id
336            .fetch_add(1, atomic::Ordering::SeqCst);
337        connection
338            .outgoing_tx
339            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
340        Ok(())
341    }
342
343    pub fn respond_with_error<T: RequestMessage>(
344        &self,
345        receipt: Receipt<T>,
346        response: proto::Error,
347    ) -> Result<()> {
348        let connection = self.connection_state(receipt.sender_id)?;
349        let message_id = connection
350            .next_message_id
351            .fetch_add(1, atomic::Ordering::SeqCst);
352        connection
353            .outgoing_tx
354            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
355        Ok(())
356    }
357
358    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
359        let connections = self.connections.read();
360        let connection = connections
361            .get(&connection_id)
362            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
363        Ok(connection.clone())
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    // Two clients talk to one server; every request must get its own matching response.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        // Answers `Ping` with `Ack` and echoes `Test` back; any other message is a test bug.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    // Messages the server sends before its response must reach the client before the
    // request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    // Dropping one in-flight request must not stall the incoming stream or other requests.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    // Disconnecting locally must end the I/O future and the incoming stream, and must
    // close the underlying connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    // If the remote end goes away mid-request, the pending request fails with
    // "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}