// peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
  5use parking_lot::{Mutex, RwLock};
  6use postage::{
  7    barrier, mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
/// A handle identifying a received request of type `T`, used to route a
/// response back to the requesting connection via `Peer::respond` /
/// `Peer::respond_with_error`.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    // Zero-sized marker tying this receipt to the request's message type, so
    // `Peer::respond` only accepts the matching `T::Response`.
    payload_type: PhantomData<T>,
}

// Clone/Copy are implemented manually rather than derived: a derive would add
// a `T: Clone` / `T: Copy` bound, but a `Receipt<T>` is copyable regardless of
// `T` since it only stores two ids and a `PhantomData`.
impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        Self {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

impl<T> Copy for Receipt<T> {}
 59
/// An incoming message of type `T` together with metadata about the
/// connection it arrived on.
pub struct TypedEnvelope<T> {
    // Connection this message was received over.
    pub sender_id: ConnectionId,
    // Id of the peer the message was forwarded on behalf of, if any
    // (populated by `Peer::forward_request` / `Peer::forward_send`).
    pub original_sender_id: Option<PeerId>,
    // Per-connection sequence number assigned by the sending side.
    pub message_id: u32,
    // The decoded message payload.
    pub payload: T,
}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
/// A node in the RPC network: owns the set of live connections and allocates
/// their ids. Constructed via [`Peer::new`] and shared behind an `Arc`.
pub struct Peer {
    // Per-connection state, keyed by connection id. Entries are inserted by
    // `add_connection` and removed by `disconnect`/`reset` or when the
    // connection's I/O future finishes.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of monotonically increasing connection ids.
    next_connection_id: AtomicU32,
}
 88
/// Per-connection bookkeeping shared between a [`Peer`] and the I/O future
/// returned by `Peer::add_connection`. Cheap to clone: every field is a
/// handle to shared state.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of envelopes waiting to be written out by the I/O future.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Next per-connection message id to assign to an outgoing message.
    next_message_id: Arc<AtomicU32>,
    // Pending request futures, keyed by the message id they await a response
    // to. The outer `Option` is taken (set to `None`) when the connection's
    // I/O finishes, which drops all pending senders and fails their requests.
    // The `barrier::Sender` travels with each response so the message loop
    // can wait until the requester has resumed before processing further
    // incoming messages.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
 96
/// Interval at which the I/O loop sends a websocket ping on each connection.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
/// Maximum time allowed for writing one outgoing message (or ping) before the
/// connection is treated as dead.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 99
impl Peer {
    /// Creates a peer with no connections. Connection ids start at zero.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

    /// Registers `connection` with this peer, returning:
    /// - the id allocated for the connection,
    /// - a future that drives the connection's I/O (reads, writes, keepalive
    ///   pings) until it errors or the connection is removed, and
    /// - a stream of incoming non-response messages as typed envelopes.
    ///
    /// `create_timer` is injected so callers (notably tests) control how the
    /// keepalive interval is awaited.
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        // Drives both directions of the connection. Returns Ok(()) when the
        // outgoing channel or the incoming consumer is closed; returns Err on
        // any read/write/ping failure or timeout.
        let handle_io = async move {
            // However the loop exits, drop all pending response channels
            // (failing their request futures) and deregister the connection.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
            });

            loop {
                // Start one read at a time; the inner loop keeps servicing
                // writes and keepalives while that read is still pending.
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    // `select_biased!` polls branches in order, so pending
                    // writes take priority over the in-flight read, and both
                    // take priority over the keepalive timer.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
                                    result.context("failed to write RPC message")?;
                                } else {
                                    Err(anyhow!("timed out writing message"))?;
                                }
                            }
                            // All senders dropped: the connection was removed,
                            // so shut down cleanly.
                            None => return Ok(()),
                        },
                        incoming = read_message => {
                            let incoming = incoming.context("received invalid rpc message")?;
                            // The consumer of `incoming_rx` went away; stop.
                            if incoming_tx.send(incoming).await.is_err() {
                                return Ok(());
                            }
                            // Restart the outer loop to begin the next read.
                            break;
                        },
                        _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
                            if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
                                result.context("failed to send websocket ping")?;
                            } else {
                                Err(anyhow!("timed out sending websocket ping"))?;
                            }
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Route responses to their waiting request futures and pass every
        // other message through as a typed envelope. Messages are processed
        // strictly in arrival order.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let mut requester_resumed = barrier::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                        // Hold off on processing further incoming messages
                        // until the requester drops its barrier sender (i.e.
                        // its request future has finished with the response).
                        requester_resumed.1.recv().await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    // Responses are consumed here; they never surface on the
                    // typed-envelope stream.
                    None
                } else {
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        log::error!("unable to construct a typed envelope");
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

    /// Test-only convenience wrapper: registers a connection whose keepalive
    /// timers run on the given background executor.
    #[cfg(any(test, feature = "test-support"))]
    pub async fn add_test_connection(
        self: &Arc<Self>,
        connection: Connection,
        executor: Arc<gpui::executor::Background>,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        let executor = executor.clone();
        self.add_connection(connection, move |duration| executor.timer(duration))
            .await
    }

    /// Removes the connection's state. Dropping the state closes its outgoing
    /// channel, which in turn makes the connection's I/O future return.
    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    /// Drops every connection (see `disconnect` for the teardown mechanics).
    pub fn reset(&self) {
        self.connections.write().clear();
    }

    /// Sends `request` to `receiver_id` and resolves with the typed response,
    /// or an error if the connection closes or the peer replies with an
    /// error payload.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

    /// Like `request`, but marks the envelope as sent on behalf of
    /// `sender_id` (populates `original_sender_id` on the receiving side).
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

    /// Shared implementation of `request`/`forward_request`. The synchronous
    /// part registers a response channel and enqueues the envelope; the
    /// returned future awaits and decodes the response.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            // Register the response channel *before* enqueueing the message
            // so a fast response can never arrive unrouted.
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // `_barrier` is held until this future completes; dropping it
            // lets the connection's message loop resume (see
            // `add_connection`).
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

    /// Enqueues a fire-and-forget message to `receiver_id`.
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, None))?;
        Ok(())
    }

    /// Like `send`, but marks the envelope as sent on behalf of `sender_id`.
    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
        Ok(())
    }

    /// Sends `response` back to the requester identified by `receipt`,
    /// tagging the envelope with the original request's message id.
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Like `respond`, but answers the request with an error payload, which
    /// the requester surfaces as "RPC request failed - …".
    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Looks up and clones the shared state handle for `connection_id`, or
    /// errors if no such connection is registered.
    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}
357
358#[cfg(test)]
359mod tests {
360    use super::*;
361    use crate::TypedEnvelope;
362    use async_tungstenite::tungstenite::Message as WebSocketMessage;
363    use gpui::TestAppContext;
364
365    #[gpui::test(iterations = 50)]
366    async fn test_request_response(cx: &mut TestAppContext) {
367        let executor = cx.foreground();
368
369        // create 2 clients connected to 1 server
370        let server = Peer::new();
371        let client1 = Peer::new();
372        let client2 = Peer::new();
373
374        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
375            Connection::in_memory(cx.background());
376        let (client1_conn_id, io_task1, client1_incoming) = client1
377            .add_test_connection(client1_to_server_conn, cx.background())
378            .await;
379        let (_, io_task2, server_incoming1) = server
380            .add_test_connection(server_to_client_1_conn, cx.background())
381            .await;
382
383        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
384            Connection::in_memory(cx.background());
385        let (client2_conn_id, io_task3, client2_incoming) = client2
386            .add_test_connection(client2_to_server_conn, cx.background())
387            .await;
388        let (_, io_task4, server_incoming2) = server
389            .add_test_connection(server_to_client_2_conn, cx.background())
390            .await;
391
392        executor.spawn(io_task1).detach();
393        executor.spawn(io_task2).detach();
394        executor.spawn(io_task3).detach();
395        executor.spawn(io_task4).detach();
396        executor
397            .spawn(handle_messages(server_incoming1, server.clone()))
398            .detach();
399        executor
400            .spawn(handle_messages(client1_incoming, client1.clone()))
401            .detach();
402        executor
403            .spawn(handle_messages(server_incoming2, server.clone()))
404            .detach();
405        executor
406            .spawn(handle_messages(client2_incoming, client2.clone()))
407            .detach();
408
409        assert_eq!(
410            client1
411                .request(client1_conn_id, proto::Ping {},)
412                .await
413                .unwrap(),
414            proto::Ack {}
415        );
416
417        assert_eq!(
418            client2
419                .request(client2_conn_id, proto::Ping {},)
420                .await
421                .unwrap(),
422            proto::Ack {}
423        );
424
425        assert_eq!(
426            client1
427                .request(client1_conn_id, proto::Test { id: 1 },)
428                .await
429                .unwrap(),
430            proto::Test { id: 1 }
431        );
432
433        assert_eq!(
434            client2
435                .request(client2_conn_id, proto::Test { id: 2 })
436                .await
437                .unwrap(),
438            proto::Test { id: 2 }
439        );
440
441        client1.disconnect(client1_conn_id);
442        client2.disconnect(client1_conn_id);
443
444        async fn handle_messages(
445            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
446            peer: Arc<Peer>,
447        ) -> Result<()> {
448            while let Some(envelope) = messages.next().await {
449                let envelope = envelope.into_any();
450                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
451                    let receipt = envelope.receipt();
452                    peer.respond(receipt, proto::Ack {})?
453                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
454                {
455                    peer.respond(envelope.receipt(), envelope.payload.clone())?
456                } else {
457                    panic!("unknown message type");
458                }
459            }
460
461            Ok(())
462        }
463    }
464
    /// Verifies delivery ordering: incoming messages are processed strictly
    /// in arrival order, so two messages the server sends *before* its
    /// response must be observed by the client before the request future
    /// resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the ping, emit two one-way error messages,
        // then answer the ping.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: consume the two one-way messages, recording each.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both pre-response messages must have been recorded before the
        // response resolved.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
564
    /// Verifies that dropping one request's future mid-flight neither breaks
    /// delivery of other messages nor the completion of a second request:
    /// the orphaned response is discarded (logged) and ordering for the
    /// surviving request is preserved.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both pings, emit two one-way messages, then
        // answer both pings (the first answer targets the dropped request).
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        // request1's task will be dropped partway through; request2 must
        // still complete normally.
        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: consume the two one-way messages, recording each.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // The dropped request leaves no trace in the event order.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
678
679    #[gpui::test(iterations = 50)]
680    async fn test_disconnect(cx: &mut TestAppContext) {
681        let executor = cx.foreground();
682
683        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
684
685        let client = Peer::new();
686        let (connection_id, io_handler, mut incoming) = client
687            .add_test_connection(client_conn, cx.background())
688            .await;
689
690        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
691        executor
692            .spawn(async move {
693                io_handler.await.ok();
694                io_ended_tx.send(()).await.unwrap();
695            })
696            .detach();
697
698        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
699        executor
700            .spawn(async move {
701                incoming.next().await;
702                messages_ended_tx.send(()).await.unwrap();
703            })
704            .detach();
705
706        client.disconnect(connection_id);
707
708        io_ended_rx.recv().await;
709        messages_ended_rx.recv().await;
710        assert!(server_conn
711            .send(WebSocketMessage::Binary(vec![]))
712            .await
713            .is_err());
714    }
715
716    #[gpui::test(iterations = 50)]
717    async fn test_io_error(cx: &mut TestAppContext) {
718        let executor = cx.foreground();
719        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
720
721        let client = Peer::new();
722        let (connection_id, io_handler, mut incoming) = client
723            .add_test_connection(client_conn, cx.background())
724            .await;
725        executor.spawn(io_handler).detach();
726        executor
727            .spawn(async move { incoming.next().await })
728            .detach();
729
730        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
731        let _request = server_conn.rx.next().await.unwrap().unwrap();
732
733        drop(server_conn);
734        assert_eq!(
735            response.await.unwrap_err().to_string(),
736            "connection was closed"
737        );
738    }
739}