// peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
  5use parking_lot::{Mutex, RwLock};
  6use postage::{
  7    barrier, mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::sync::atomic::Ordering::SeqCst;
 12use std::{
 13    collections::HashMap,
 14    fmt,
 15    future::Future,
 16    marker::PhantomData,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21    time::Duration,
 22};
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct ConnectionId(pub u32);
 26
 27impl fmt::Display for ConnectionId {
 28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 29        self.0.fmt(f)
 30    }
 31}
 32
 33#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 34pub struct PeerId(pub u32);
 35
 36impl fmt::Display for PeerId {
 37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 38        self.0.fmt(f)
 39    }
 40}
 41
 42pub struct Receipt<T> {
 43    pub sender_id: ConnectionId,
 44    pub message_id: u32,
 45    payload_type: PhantomData<T>,
 46}
 47
 48impl<T> Clone for Receipt<T> {
 49    fn clone(&self) -> Self {
 50        Self {
 51            sender_id: self.sender_id,
 52            message_id: self.message_id,
 53            payload_type: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<T> Copy for Receipt<T> {}
 59
 60pub struct TypedEnvelope<T> {
 61    pub sender_id: ConnectionId,
 62    pub original_sender_id: Option<PeerId>,
 63    pub message_id: u32,
 64    pub payload: T,
 65}
 66
 67impl<T> TypedEnvelope<T> {
 68    pub fn original_sender_id(&self) -> Result<PeerId> {
 69        self.original_sender_id
 70            .ok_or_else(|| anyhow!("missing original_sender_id"))
 71    }
 72}
 73
 74impl<T: RequestMessage> TypedEnvelope<T> {
 75    pub fn receipt(&self) -> Receipt<T> {
 76        Receipt {
 77            sender_id: self.sender_id,
 78            message_id: self.message_id,
 79            payload_type: PhantomData,
 80        }
 81    }
 82}
 83
/// Manages a set of RPC connections: routes outgoing messages and
/// requests, and matches incoming responses to their pending requests.
pub struct Peer {
    // Active connections keyed by id; entries are removed on disconnect
    // or when a connection's I/O task ends. Public so callers can inspect
    // connection state directly.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of unique ids for connections added to this peer.
    next_connection_id: AtomicU32,
}
 88
/// Per-connection bookkeeping shared between the peer and the
/// connection's I/O task. Cheap to clone: every field is a handle.
#[derive(Clone)]
pub struct ConnectionState {
    // Envelopes queued here are written to the socket by the I/O task.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Source of unique ids for messages sent over this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending requests awaiting responses, keyed by request message id.
    // The outer Option is taken (set to None) when the connection closes,
    // which drops the senders and fails all in-flight requests.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
 96
// Upper bound on a single message write; exceeding it is treated as a
// dead connection and tears down the I/O task.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 98
 99impl Peer {
100    pub fn new() -> Arc<Self> {
101        Arc::new(Self {
102            connections: Default::default(),
103            next_connection_id: Default::default(),
104        })
105    }
106
107    pub async fn add_connection(
108        self: &Arc<Self>,
109        connection: Connection,
110    ) -> (
111        ConnectionId,
112        impl Future<Output = anyhow::Result<()>> + Send,
113        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114    ) {
115        // For outgoing messages, use an unbounded channel so that application code
116        // can always send messages without yielding. For incoming messages, use a
117        // bounded channel so that other peers will receive backpressure if they send
118        // messages faster than this peer can process them.
119        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
120        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
121
122        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
123        let connection_state = ConnectionState {
124            outgoing_tx,
125            next_message_id: Default::default(),
126            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
127        };
128        let mut writer = MessageStream::new(connection.tx);
129        let mut reader = MessageStream::new(connection.rx);
130
131        let this = self.clone();
132        let response_channels = connection_state.response_channels.clone();
133        let handle_io = async move {
134            let result = 'outer: loop {
135                let read_message = reader.read_message().fuse();
136                futures::pin_mut!(read_message);
137                loop {
138                    futures::select_biased! {
139                        outgoing = outgoing_rx.next().fuse() => match outgoing {
140                            Some(outgoing) => {
141                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
142                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
143                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
144                                    _ => {}
145                                }
146                            }
147                            None => break 'outer Ok(()),
148                        },
149                        incoming = read_message => match incoming {
150                            Ok(incoming) => {
151                                if incoming_tx.send(incoming).await.is_err() {
152                                    break 'outer Ok(());
153                                }
154                                break;
155                            }
156                            Err(error) => {
157                                break 'outer Err(error).context("received invalid RPC message")
158                            }
159                        },
160                    }
161                }
162            };
163
164            response_channels.lock().take();
165            this.connections.write().remove(&connection_id);
166            result
167        };
168
169        let response_channels = connection_state.response_channels.clone();
170        self.connections
171            .write()
172            .insert(connection_id, connection_state);
173
174        let incoming_rx = incoming_rx.filter_map(move |incoming| {
175            let response_channels = response_channels.clone();
176            async move {
177                if let Some(responding_to) = incoming.responding_to {
178                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
179                    if let Some(tx) = channel {
180                        let mut requester_resumed = barrier::channel();
181                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
182                            log::debug!(
183                                "received RPC but request future was dropped {:?}",
184                                error.0
185                            );
186                        }
187                        requester_resumed.1.recv().await;
188                    } else {
189                        log::warn!("received RPC response to unknown request {}", responding_to);
190                    }
191
192                    None
193                } else {
194                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
195                        Some(envelope)
196                    } else {
197                        log::error!("unable to construct a typed envelope");
198                        None
199                    }
200                }
201            }
202        });
203        (connection_id, handle_io, incoming_rx.boxed())
204    }
205
206    pub fn disconnect(&self, connection_id: ConnectionId) {
207        self.connections.write().remove(&connection_id);
208    }
209
210    pub fn reset(&self) {
211        self.connections.write().clear();
212    }
213
    /// Sends `request` over `receiver_id`'s connection and resolves with the
    /// typed response. The request is enqueued synchronously when this method
    /// is called; only awaiting the response is deferred to the future.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
221
    /// Like [`Self::request`], but stamps the envelope with `sender_id` as
    /// the original sender, for requests relayed on behalf of another peer.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
230
231    pub fn request_internal<T: RequestMessage>(
232        &self,
233        original_sender_id: Option<ConnectionId>,
234        receiver_id: ConnectionId,
235        request: T,
236    ) -> impl Future<Output = Result<T::Response>> {
237        let (tx, rx) = oneshot::channel();
238        let send = self.connection_state(receiver_id).and_then(|connection| {
239            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
240            connection
241                .response_channels
242                .lock()
243                .as_mut()
244                .ok_or_else(|| anyhow!("connection was closed"))?
245                .insert(message_id, tx);
246            connection
247                .outgoing_tx
248                .unbounded_send(request.into_envelope(
249                    message_id,
250                    None,
251                    original_sender_id.map(|id| id.0),
252                ))
253                .map_err(|_| anyhow!("connection was closed"))?;
254            Ok(())
255        });
256        async move {
257            send?;
258            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
259            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
260                Err(anyhow!("RPC request failed - {}", error.message))
261            } else {
262                T::Response::from_envelope(response)
263                    .ok_or_else(|| anyhow!("received response of the wrong type"))
264            }
265        }
266    }
267
268    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
269        let connection = self.connection_state(receiver_id)?;
270        let message_id = connection
271            .next_message_id
272            .fetch_add(1, atomic::Ordering::SeqCst);
273        connection
274            .outgoing_tx
275            .unbounded_send(message.into_envelope(message_id, None, None))?;
276        Ok(())
277    }
278
279    pub fn forward_send<T: EnvelopedMessage>(
280        &self,
281        sender_id: ConnectionId,
282        receiver_id: ConnectionId,
283        message: T,
284    ) -> Result<()> {
285        let connection = self.connection_state(receiver_id)?;
286        let message_id = connection
287            .next_message_id
288            .fetch_add(1, atomic::Ordering::SeqCst);
289        connection
290            .outgoing_tx
291            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
292        Ok(())
293    }
294
295    pub fn respond<T: RequestMessage>(
296        &self,
297        receipt: Receipt<T>,
298        response: T::Response,
299    ) -> Result<()> {
300        let connection = self.connection_state(receipt.sender_id)?;
301        let message_id = connection
302            .next_message_id
303            .fetch_add(1, atomic::Ordering::SeqCst);
304        connection
305            .outgoing_tx
306            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
307        Ok(())
308    }
309
310    pub fn respond_with_error<T: RequestMessage>(
311        &self,
312        receipt: Receipt<T>,
313        response: proto::Error,
314    ) -> Result<()> {
315        let connection = self.connection_state(receipt.sender_id)?;
316        let message_id = connection
317            .next_message_id
318            .fetch_add(1, atomic::Ordering::SeqCst);
319        connection
320            .outgoing_tx
321            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
322        Ok(())
323    }
324
325    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
326        let connections = self.connections.read();
327        let connection = connections
328            .get(&connection_id)
329            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
330        Ok(connection.clone())
331    }
332}
333
334#[cfg(test)]
335mod tests {
336    use super::*;
337    use crate::TypedEnvelope;
338    use async_tungstenite::tungstenite::Message as WebSocketMessage;
339    use gpui::TestAppContext;
340
341    #[gpui::test(iterations = 50)]
342    async fn test_request_response(cx: &mut TestAppContext) {
343        let executor = cx.foreground();
344
345        // create 2 clients connected to 1 server
346        let server = Peer::new();
347        let client1 = Peer::new();
348        let client2 = Peer::new();
349
350        let (client1_to_server_conn, server_to_client_1_conn, _) =
351            Connection::in_memory(cx.background());
352        let (client1_conn_id, io_task1, client1_incoming) =
353            client1.add_connection(client1_to_server_conn).await;
354        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
355
356        let (client2_to_server_conn, server_to_client_2_conn, _) =
357            Connection::in_memory(cx.background());
358        let (client2_conn_id, io_task3, client2_incoming) =
359            client2.add_connection(client2_to_server_conn).await;
360        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
361
362        executor.spawn(io_task1).detach();
363        executor.spawn(io_task2).detach();
364        executor.spawn(io_task3).detach();
365        executor.spawn(io_task4).detach();
366        executor
367            .spawn(handle_messages(server_incoming1, server.clone()))
368            .detach();
369        executor
370            .spawn(handle_messages(client1_incoming, client1.clone()))
371            .detach();
372        executor
373            .spawn(handle_messages(server_incoming2, server.clone()))
374            .detach();
375        executor
376            .spawn(handle_messages(client2_incoming, client2.clone()))
377            .detach();
378
379        assert_eq!(
380            client1
381                .request(client1_conn_id, proto::Ping {},)
382                .await
383                .unwrap(),
384            proto::Ack {}
385        );
386
387        assert_eq!(
388            client2
389                .request(client2_conn_id, proto::Ping {},)
390                .await
391                .unwrap(),
392            proto::Ack {}
393        );
394
395        assert_eq!(
396            client1
397                .request(client1_conn_id, proto::Test { id: 1 },)
398                .await
399                .unwrap(),
400            proto::Test { id: 1 }
401        );
402
403        assert_eq!(
404            client2
405                .request(client2_conn_id, proto::Test { id: 2 })
406                .await
407                .unwrap(),
408            proto::Test { id: 2 }
409        );
410
411        client1.disconnect(client1_conn_id);
412        client2.disconnect(client1_conn_id);
413
414        async fn handle_messages(
415            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
416            peer: Arc<Peer>,
417        ) -> Result<()> {
418            while let Some(envelope) = messages.next().await {
419                let envelope = envelope.into_any();
420                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
421                    let receipt = envelope.receipt();
422                    peer.respond(receipt, proto::Ack {})?
423                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
424                {
425                    peer.respond(envelope.receipt(), envelope.payload.clone())?
426                } else {
427                    panic!("unknown message type");
428                }
429            }
430
431            Ok(())
432        }
433    }
434
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        // Verifies ordering: a response must be observed by the requester
        // only after every message the server sent before the response has
        // been processed by the incoming stream.
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive the ping, send two errors, then respond.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: record the two error messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both server-sent messages must be observed before the response.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
532
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        // Verifies that dropping one pending request mid-flight does not
        // disturb the ordering or completion of another request on the
        // same connection.
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both pings, send two errors, answer both.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: record the two error messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // Request 1's response never appears; ordering is otherwise intact.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
644
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        // Verifies that `disconnect` ends the I/O task, terminates the
        // incoming-message stream, and closes the underlying connection.
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        // Barrier fires when the I/O future resolves.
        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        // Barrier fires when the incoming stream yields (here: ends).
        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        // The server's half of the connection should now fail to send.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
679
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        // Verifies that dropping the remote end of the connection fails a
        // pending request with a "connection was closed" error.
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Ensure the request actually reached the wire before dropping.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
701}