peer.rs

   1use super::{
   2    Connection,
   3    message_stream::{Message, MessageStream},
   4    proto::{
   5        self, AnyTypedEnvelope, EnvelopedMessage, PeerId, Receipt, RequestMessage, TypedEnvelope,
   6    },
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use collections::HashMap;
  10use futures::{
  11    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
  12    channel::{mpsc, oneshot},
  13    stream::BoxStream,
  14};
  15use parking_lot::{Mutex, RwLock};
  16use proto::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
  17use serde::{Serialize, ser::SerializeStruct};
  18use std::{
  19    fmt, future,
  20    future::Future,
  21    sync::atomic::Ordering::SeqCst,
  22    sync::{
  23        Arc,
  24        atomic::{self, AtomicU32},
  25    },
  26    time::Duration,
  27    time::Instant,
  28};
  29
  30#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
  31pub struct ConnectionId {
  32    pub owner_id: u32,
  33    pub id: u32,
  34}
  35
  36impl From<ConnectionId> for PeerId {
  37    fn from(id: ConnectionId) -> Self {
  38        PeerId {
  39            owner_id: id.owner_id,
  40            id: id.id,
  41        }
  42    }
  43}
  44
  45impl From<PeerId> for ConnectionId {
  46    fn from(peer_id: PeerId) -> Self {
  47        Self {
  48            owner_id: peer_id.owner_id,
  49            id: peer_id.id,
  50        }
  51    }
  52}
  53
  54impl fmt::Display for ConnectionId {
  55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  56        write!(f, "{}/{}", self.owner_id, self.id)
  57    }
  58}
  59
  60pub struct Peer {
  61    epoch: AtomicU32,
  62    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
  63    next_connection_id: AtomicU32,
  64}
  65
  66#[derive(Clone, Serialize)]
  67pub struct ConnectionState {
  68    #[serde(skip)]
  69    outgoing_tx: mpsc::UnboundedSender<Message>,
  70    next_message_id: Arc<AtomicU32>,
  71    #[allow(clippy::type_complexity)]
  72    #[serde(skip)]
  73    response_channels: Arc<
  74        Mutex<
  75            Option<
  76                HashMap<
  77                    u32,
  78                    oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
  79                >,
  80            >,
  81        >,
  82    >,
  83    #[allow(clippy::type_complexity)]
  84    #[serde(skip)]
  85    stream_response_channels: Arc<
  86        Mutex<
  87            Option<
  88                HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
  89            >,
  90        >,
  91    >,
  92}
  93
  94const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
  95const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
  96pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
  97
  98impl Peer {
  99    pub fn new(epoch: u32) -> Arc<Self> {
 100        Arc::new(Self {
 101            epoch: AtomicU32::new(epoch),
 102            connections: Default::default(),
 103            next_connection_id: Default::default(),
 104        })
 105    }
 106
 107    pub fn epoch(&self) -> u32 {
 108        self.epoch.load(SeqCst)
 109    }
 110
 111    pub fn add_connection<F, Fut, Out>(
 112        self: &Arc<Self>,
 113        connection: Connection,
 114        create_timer: F,
 115    ) -> (
 116        ConnectionId,
 117        impl Future<Output = anyhow::Result<()>> + Send + use<F, Fut, Out>,
 118        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 119    )
 120    where
 121        F: Send + Fn(Duration) -> Fut,
 122        Fut: Send + Future<Output = Out>,
 123        Out: Send,
 124    {
 125        // For outgoing messages, use an unbounded channel so that application code
 126        // can always send messages without yielding. For incoming messages, use a
 127        // bounded channel so that other peers will receive backpressure if they send
 128        // messages faster than this peer can process them.
 129        #[cfg(any(test, feature = "test-support"))]
 130        const INCOMING_BUFFER_SIZE: usize = 1;
 131        #[cfg(not(any(test, feature = "test-support")))]
 132        const INCOMING_BUFFER_SIZE: usize = 256;
 133        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
 134        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
 135
 136        let connection_id = ConnectionId {
 137            owner_id: self.epoch.load(SeqCst),
 138            id: self.next_connection_id.fetch_add(1, SeqCst),
 139        };
 140        let connection_state = ConnectionState {
 141            outgoing_tx,
 142            next_message_id: Default::default(),
 143            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
 144            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
 145        };
 146        let mut writer = MessageStream::new(connection.tx);
 147        let mut reader = MessageStream::new(connection.rx);
 148
 149        let this = self.clone();
 150        let response_channels = connection_state.response_channels.clone();
 151        let stream_response_channels = connection_state.stream_response_channels.clone();
 152
 153        let handle_io = async move {
 154            tracing::trace!(%connection_id, "handle io future: start");
 155
 156            let _end_connection = util::defer(|| {
 157                response_channels.lock().take();
 158                if let Some(channels) = stream_response_channels.lock().take() {
 159                    for channel in channels.values() {
 160                        let _ = channel.unbounded_send((
 161                            Err(anyhow!("connection closed")),
 162                            oneshot::channel().0,
 163                        ));
 164                    }
 165                }
 166                this.connections.write().remove(&connection_id);
 167                tracing::trace!(%connection_id, "handle io future: end");
 168            });
 169
 170            // Send messages on this frequency so the connection isn't closed.
 171            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
 172            futures::pin_mut!(keepalive_timer);
 173
 174            // Disconnect if we don't receive messages at least this frequently.
 175            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
 176            futures::pin_mut!(receive_timeout);
 177
 178            loop {
 179                tracing::trace!(%connection_id, "outer loop iteration start");
 180                let read_message = reader.read().fuse();
 181                futures::pin_mut!(read_message);
 182
 183                loop {
 184                    tracing::trace!(%connection_id, "inner loop iteration start");
 185                    futures::select_biased! {
 186                        outgoing = outgoing_rx.next().fuse() => match outgoing {
 187                            Some(outgoing) => {
 188                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
 189                                futures::select_biased! {
 190                                    result = writer.write(outgoing).fuse() => {
 191                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
 192                                        result.context("failed to write RPC message")?;
 193                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
 194                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
 195                                    }
 196                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
 197                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
 198                                        anyhow::bail!("timed out writing message");
 199                                    }
 200                                }
 201                            }
 202                            None => {
 203                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
 204                                return Ok(())
 205                            },
 206                        },
 207                        _ = keepalive_timer => {
 208                            tracing::trace!(%connection_id, "keepalive interval: pinging");
 209                            futures::select_biased! {
 210                                result = writer.write(Message::Ping).fuse() => {
 211                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
 212                                    result.context("failed to send keepalive")?;
 213                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
 214                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
 215                                }
 216                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
 217                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
 218                                    anyhow::bail!("timed out sending keepalive");
 219                                }
 220                            }
 221                        }
 222                        incoming = read_message => {
 223                            let incoming = incoming.context("error reading rpc message from socket")?;
 224                            tracing::trace!(%connection_id, "incoming rpc message: received");
 225                            tracing::trace!(%connection_id, "receive timeout: resetting");
 226                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
 227                            if let (Message::Envelope(incoming), received_at) = incoming {
 228                                tracing::trace!(%connection_id, "incoming rpc message: processing");
 229                                futures::select_biased! {
 230                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
 231                                        Ok(_) => {
 232                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
 233                                        }
 234                                        Err(_) => {
 235                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
 236                                            return Ok(())
 237                                        }
 238                                    },
 239                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
 240                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
 241                                        anyhow::bail!("timed out processing incoming message");
 242                                    }
 243                                }
 244                            }
 245                            break;
 246                        },
 247                        _ = receive_timeout => {
 248                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
 249                            anyhow::bail!("delay between messages too long");
 250                        }
 251                    }
 252                }
 253            }
 254        };
 255
 256        let response_channels = connection_state.response_channels.clone();
 257        let stream_response_channels = connection_state.stream_response_channels.clone();
 258        self.connections
 259            .write()
 260            .insert(connection_id, connection_state);
 261
 262        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
 263            let response_channels = response_channels.clone();
 264            let stream_response_channels = stream_response_channels.clone();
 265            async move {
 266                let message_id = incoming.id;
 267                tracing::trace!(?incoming, "incoming message future: start");
 268                let _end = util::defer(move || {
 269                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
 270                });
 271
 272                if let Some(responding_to) = incoming.responding_to {
 273                    tracing::trace!(
 274                        %connection_id,
 275                        message_id,
 276                        responding_to,
 277                        "incoming response: received"
 278                    );
 279                    let response_channel =
 280                        response_channels.lock().as_mut()?.remove(&responding_to);
 281                    let stream_response_channel = stream_response_channels
 282                        .lock()
 283                        .as_ref()?
 284                        .get(&responding_to)
 285                        .cloned();
 286
 287                    if let Some(tx) = response_channel {
 288                        let requester_resumed = oneshot::channel();
 289                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
 290                            tracing::trace!(
 291                                %connection_id,
 292                                message_id,
 293                                responding_to = responding_to,
 294                                ?error,
 295                                "incoming response: request future dropped",
 296                            );
 297                        }
 298
 299                        tracing::trace!(
 300                            %connection_id,
 301                            message_id,
 302                            responding_to,
 303                            "incoming response: waiting to resume requester"
 304                        );
 305                        let _ = requester_resumed.1.await;
 306                        tracing::trace!(
 307                            %connection_id,
 308                            message_id,
 309                            responding_to,
 310                            "incoming response: requester resumed"
 311                        );
 312                    } else if let Some(tx) = stream_response_channel {
 313                        let requester_resumed = oneshot::channel();
 314                        if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
 315                            tracing::debug!(
 316                                %connection_id,
 317                                message_id,
 318                                responding_to = responding_to,
 319                                ?error,
 320                                "incoming stream response: request future dropped",
 321                            );
 322                        }
 323
 324                        tracing::debug!(
 325                            %connection_id,
 326                            message_id,
 327                            responding_to,
 328                            "incoming stream response: waiting to resume requester"
 329                        );
 330                        let _ = requester_resumed.1.await;
 331                        tracing::debug!(
 332                            %connection_id,
 333                            message_id,
 334                            responding_to,
 335                            "incoming stream response: requester resumed"
 336                        );
 337                    } else {
 338                        let message_type = proto::build_typed_envelope(
 339                            connection_id.into(),
 340                            received_at,
 341                            incoming,
 342                        )
 343                        .map(|p| p.payload_type_name());
 344                        tracing::warn!(
 345                            %connection_id,
 346                            message_id,
 347                            responding_to,
 348                            message_type,
 349                            "incoming response: unknown request"
 350                        );
 351                    }
 352
 353                    None
 354                } else {
 355                    tracing::trace!(%connection_id, message_id, "incoming message: received");
 356                    proto::build_typed_envelope(connection_id.into(), received_at, incoming)
 357                        .or_else(|| {
 358                            tracing::error!(
 359                                %connection_id,
 360                                message_id,
 361                                "unable to construct a typed envelope"
 362                            );
 363                            None
 364                        })
 365                }
 366            }
 367        });
 368        (connection_id, handle_io, incoming_rx.boxed())
 369    }
 370
 371    #[cfg(any(test, feature = "test-support"))]
 372    pub fn add_test_connection(
 373        self: &Arc<Self>,
 374        connection: Connection,
 375        executor: gpui::BackgroundExecutor,
 376    ) -> (
 377        ConnectionId,
 378        impl Future<Output = anyhow::Result<()>> + Send + use<>,
 379        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 380    ) {
 381        let executor = executor.clone();
 382        self.add_connection(connection, move |duration| executor.timer(duration))
 383    }
 384
 385    pub fn disconnect(&self, connection_id: ConnectionId) {
 386        self.connections.write().remove(&connection_id);
 387    }
 388
 389    #[cfg(any(test, feature = "test-support"))]
 390    pub fn reset(&self, epoch: u32) {
 391        self.next_connection_id.store(0, SeqCst);
 392        self.epoch.store(epoch, SeqCst);
 393    }
 394
 395    pub fn teardown(&self) {
 396        self.connections.write().clear();
 397    }
 398
 399    /// Make a request and wait for a response.
 400    pub fn request<T: RequestMessage>(
 401        &self,
 402        receiver_id: ConnectionId,
 403        request: T,
 404    ) -> impl Future<Output = Result<T::Response>> + use<T> {
 405        self.request_internal(None, receiver_id, request)
 406            .map_ok(|envelope| envelope.payload)
 407    }
 408
 409    pub fn request_envelope<T: RequestMessage>(
 410        &self,
 411        receiver_id: ConnectionId,
 412        request: T,
 413    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
 414        self.request_internal(None, receiver_id, request)
 415    }
 416
 417    pub fn forward_request<T: RequestMessage>(
 418        &self,
 419        sender_id: ConnectionId,
 420        receiver_id: ConnectionId,
 421        request: T,
 422    ) -> impl Future<Output = Result<T::Response>> {
 423        self.request_internal(Some(sender_id), receiver_id, request)
 424            .map_ok(|envelope| envelope.payload)
 425    }
 426
 427    fn request_internal<T: RequestMessage>(
 428        &self,
 429        original_sender_id: Option<ConnectionId>,
 430        receiver_id: ConnectionId,
 431        request: T,
 432    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
 433        let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
 434        let response = self.request_dynamic(receiver_id, envelope, T::NAME);
 435        async move {
 436            let (response, received_at) = response.await?;
 437            Ok(TypedEnvelope {
 438                message_id: response.id,
 439                sender_id: receiver_id.into(),
 440                original_sender_id: response.original_sender_id,
 441                payload: T::Response::from_envelope(response)
 442                    .context("received response of the wrong type")?,
 443                received_at,
 444            })
 445        }
 446    }
 447
 448    /// Make a request and wait for a response.
 449    ///
 450    /// The caller must make sure to deserialize the response into the request's
 451    /// response type. This interface is only useful in trait objects, where
 452    /// generics can't be used. If you have a concrete type, use `request`.
 453    pub fn request_dynamic(
 454        &self,
 455        receiver_id: ConnectionId,
 456        mut envelope: proto::Envelope,
 457        type_name: &'static str,
 458    ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> + use<> {
 459        let (tx, rx) = oneshot::channel();
 460        let send = self.connection_state(receiver_id).and_then(|connection| {
 461            envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
 462            connection
 463                .response_channels
 464                .lock()
 465                .as_mut()
 466                .context("connection was closed")?
 467                .insert(envelope.id, tx);
 468            connection
 469                .outgoing_tx
 470                .unbounded_send(Message::Envelope(envelope))
 471                .context("connection was closed")?;
 472            Ok(())
 473        });
 474        async move {
 475            send?;
 476            let (response, received_at, _barrier) = rx.await.context("connection was closed")?;
 477            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
 478                return Err(RpcError::from_proto(error, type_name));
 479            }
 480            Ok((response, received_at))
 481        }
 482    }
 483
 484    pub fn request_stream<T: RequestMessage>(
 485        &self,
 486        receiver_id: ConnectionId,
 487        request: T,
 488    ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
 489        let (tx, rx) = mpsc::unbounded();
 490        let send = self.connection_state(receiver_id).and_then(|connection| {
 491            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
 492            let stream_response_channels = connection.stream_response_channels.clone();
 493            stream_response_channels
 494                .lock()
 495                .as_mut()
 496                .context("connection was closed")?
 497                .insert(message_id, tx);
 498            connection
 499                .outgoing_tx
 500                .unbounded_send(Message::Envelope(
 501                    request.into_envelope(message_id, None, None),
 502                ))
 503                .context("connection was closed")?;
 504            Ok((message_id, stream_response_channels))
 505        });
 506
 507        async move {
 508            let (message_id, stream_response_channels) = send?;
 509            let stream_response_channels = Arc::downgrade(&stream_response_channels);
 510
 511            Ok(rx.filter_map(move |(response, _barrier)| {
 512                let stream_response_channels = stream_response_channels.clone();
 513                future::ready(match response {
 514                    Ok(response) => {
 515                        if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
 516                            Some(Err(RpcError::from_proto(error, T::NAME)))
 517                        } else if let Some(proto::envelope::Payload::EndStream(_)) =
 518                            &response.payload
 519                        {
 520                            // Remove the transmitting end of the response channel to end the stream.
 521                            if let Some(channels) = stream_response_channels.upgrade()
 522                                && let Some(channels) = channels.lock().as_mut()
 523                            {
 524                                channels.remove(&message_id);
 525                            }
 526                            None
 527                        } else {
 528                            Some(
 529                                T::Response::from_envelope(response)
 530                                    .context("received response of the wrong type"),
 531                            )
 532                        }
 533                    }
 534                    Err(error) => Some(Err(error)),
 535                })
 536            }))
 537        }
 538    }
 539
 540    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
 541        let connection = self.connection_state(receiver_id)?;
 542        let message_id = connection
 543            .next_message_id
 544            .fetch_add(1, atomic::Ordering::SeqCst);
 545        connection.outgoing_tx.unbounded_send(Message::Envelope(
 546            message.into_envelope(message_id, None, None),
 547        ))?;
 548        Ok(())
 549    }
 550
 551    pub fn send_dynamic(&self, receiver_id: ConnectionId, message: proto::Envelope) -> Result<()> {
 552        let connection = self.connection_state(receiver_id)?;
 553        connection
 554            .outgoing_tx
 555            .unbounded_send(Message::Envelope(message))?;
 556        Ok(())
 557    }
 558
 559    pub fn forward_send<T: EnvelopedMessage>(
 560        &self,
 561        sender_id: ConnectionId,
 562        receiver_id: ConnectionId,
 563        message: T,
 564    ) -> Result<()> {
 565        let connection = self.connection_state(receiver_id)?;
 566        let message_id = connection
 567            .next_message_id
 568            .fetch_add(1, atomic::Ordering::SeqCst);
 569        connection
 570            .outgoing_tx
 571            .unbounded_send(Message::Envelope(message.into_envelope(
 572                message_id,
 573                None,
 574                Some(sender_id.into()),
 575            )))?;
 576        Ok(())
 577    }
 578
 579    pub fn respond<T: RequestMessage>(
 580        &self,
 581        receipt: Receipt<T>,
 582        response: T::Response,
 583    ) -> Result<()> {
 584        let connection = self.connection_state(receipt.sender_id.into())?;
 585        let message_id = connection
 586            .next_message_id
 587            .fetch_add(1, atomic::Ordering::SeqCst);
 588        connection
 589            .outgoing_tx
 590            .unbounded_send(Message::Envelope(response.into_envelope(
 591                message_id,
 592                Some(receipt.message_id),
 593                None,
 594            )))?;
 595        Ok(())
 596    }
 597
 598    pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
 599        let connection = self.connection_state(receipt.sender_id.into())?;
 600        let message_id = connection
 601            .next_message_id
 602            .fetch_add(1, atomic::Ordering::SeqCst);
 603
 604        let message = proto::EndStream {};
 605
 606        connection
 607            .outgoing_tx
 608            .unbounded_send(Message::Envelope(message.into_envelope(
 609                message_id,
 610                Some(receipt.message_id),
 611                None,
 612            )))?;
 613        Ok(())
 614    }
 615
 616    pub fn respond_with_error<T: RequestMessage>(
 617        &self,
 618        receipt: Receipt<T>,
 619        response: proto::Error,
 620    ) -> Result<()> {
 621        let connection = self.connection_state(receipt.sender_id.into())?;
 622        let message_id = connection
 623            .next_message_id
 624            .fetch_add(1, atomic::Ordering::SeqCst);
 625        connection
 626            .outgoing_tx
 627            .unbounded_send(Message::Envelope(response.into_envelope(
 628                message_id,
 629                Some(receipt.message_id),
 630                None,
 631            )))?;
 632        Ok(())
 633    }
 634
 635    pub fn respond_with_unhandled_message(
 636        &self,
 637        sender_id: ConnectionId,
 638        request_message_id: u32,
 639        message_type_name: &'static str,
 640    ) -> Result<()> {
 641        let connection = self.connection_state(sender_id)?;
 642        let response = ErrorCode::Internal
 643            .message(format!("message {} was not handled", message_type_name))
 644            .to_proto();
 645        let message_id = connection
 646            .next_message_id
 647            .fetch_add(1, atomic::Ordering::SeqCst);
 648        connection
 649            .outgoing_tx
 650            .unbounded_send(Message::Envelope(response.into_envelope(
 651                message_id,
 652                Some(request_message_id),
 653                None,
 654            )))?;
 655        Ok(())
 656    }
 657
 658    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
 659        let connections = self.connections.read();
 660        let connection = connections
 661            .get(&connection_id)
 662            .with_context(|| format!("no such connection: {connection_id}"))?;
 663        Ok(connection.clone())
 664    }
 665}
 666
 667impl Serialize for Peer {
 668    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 669    where
 670        S: serde::Serializer,
 671    {
 672        let mut state = serializer.serialize_struct("Peer", 2)?;
 673        state.serialize_field("connections", &*self.connections.read())?;
 674        state.end()
 675    }
 676}
 677
 678#[cfg(test)]
 679mod tests {
 680    use super::*;
 681    use async_tungstenite::tungstenite::Message as WebSocketMessage;
 682    use gpui::TestAppContext;
 683
 684    fn init_logger() {
 685        zlog::init_test();
 686    }
 687
 688    #[gpui::test(iterations = 50)]
 689    async fn test_request_response(cx: &mut TestAppContext) {
 690        init_logger();
 691
 692        let executor = cx.executor();
 693
 694        // create 2 clients connected to 1 server
 695        let server = Peer::new(0);
 696        let client1 = Peer::new(0);
 697        let client2 = Peer::new(0);
 698
 699        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
 700            Connection::in_memory(cx.executor());
 701        let (client1_conn_id, io_task1, client1_incoming) =
 702            client1.add_test_connection(client1_to_server_conn, cx.executor());
 703        let (_, io_task2, server_incoming1) =
 704            server.add_test_connection(server_to_client_1_conn, cx.executor());
 705
 706        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
 707            Connection::in_memory(cx.executor());
 708        let (client2_conn_id, io_task3, client2_incoming) =
 709            client2.add_test_connection(client2_to_server_conn, cx.executor());
 710        let (_, io_task4, server_incoming2) =
 711            server.add_test_connection(server_to_client_2_conn, cx.executor());
 712
 713        executor.spawn(io_task1).detach();
 714        executor.spawn(io_task2).detach();
 715        executor.spawn(io_task3).detach();
 716        executor.spawn(io_task4).detach();
 717        executor
 718            .spawn(handle_messages(server_incoming1, server.clone()))
 719            .detach();
 720        executor
 721            .spawn(handle_messages(client1_incoming, client1.clone()))
 722            .detach();
 723        executor
 724            .spawn(handle_messages(server_incoming2, server.clone()))
 725            .detach();
 726        executor
 727            .spawn(handle_messages(client2_incoming, client2.clone()))
 728            .detach();
 729
 730        assert_eq!(
 731            client1
 732                .request(client1_conn_id, proto::Ping {},)
 733                .await
 734                .unwrap(),
 735            proto::Ack {}
 736        );
 737
 738        assert_eq!(
 739            client2
 740                .request(client2_conn_id, proto::Ping {},)
 741                .await
 742                .unwrap(),
 743            proto::Ack {}
 744        );
 745
 746        assert_eq!(
 747            client1
 748                .request(client1_conn_id, proto::Test { id: 1 },)
 749                .await
 750                .unwrap(),
 751            proto::Test { id: 1 }
 752        );
 753
 754        assert_eq!(
 755            client2
 756                .request(client2_conn_id, proto::Test { id: 2 })
 757                .await
 758                .unwrap(),
 759            proto::Test { id: 2 }
 760        );
 761
 762        client1.disconnect(client1_conn_id);
 763        client2.disconnect(client1_conn_id);
 764
 765        async fn handle_messages(
 766            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 767            peer: Arc<Peer>,
 768        ) -> Result<()> {
 769            while let Some(envelope) = messages.next().await {
 770                let envelope = envelope.into_any();
 771                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
 772                    let receipt = envelope.receipt();
 773                    peer.respond(receipt, proto::Ack {})?
 774                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
 775                {
 776                    peer.respond(envelope.receipt(), envelope.payload.clone())?
 777                } else {
 778                    panic!("unknown message type");
 779                }
 780            }
 781
 782            Ok(())
 783        }
 784    }
 785
    // Checks that messages the server sends *before* answering a request reach the
    // client's incoming stream before that request's response future resolves. The
    // recorded event order must be "message 1", "message 2", "response".
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the client's Ping, push two Error messages to the
        // client, then ack the Ping, so both messages are sent ahead of the response.
        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        // Client side: append "response" once the Ping request resolves.
        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: append the text of each of the two Error messages as they
        // arrive on the incoming stream.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Once the response has been recorded, both incoming messages must already
        // have been recorded ahead of it.
        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
 883
    // The client issues two Ping requests and drops the first request's task after a
    // random delay. The server sends two Error messages and then acks both Pings. The
    // test checks that dropping an in-flight request does not disturb delivery of
    // incoming messages or of the other request's response. The expected order is
    // "message 1", "message 2", "response 2".
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both Pings, send two Error messages, then respond to
        // both requests (including the one the client may already have dropped).
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        // request1 runs in its own task, which is dropped below. request2's task
        // appends "response 2" once its response arrives.
        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: append the text of each of the two Error messages as they
        // arrive on the incoming stream.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        // request2 must still complete, after both incoming messages were recorded.
        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
 995
 996    #[gpui::test(iterations = 50)]
 997    async fn test_disconnect(cx: &mut TestAppContext) {
 998        let executor = cx.executor();
 999
1000        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
1001
1002        let client = Peer::new(0);
1003        let (connection_id, io_handler, mut incoming) =
1004            client.add_test_connection(client_conn, executor.clone());
1005
1006        let (io_ended_tx, io_ended_rx) = oneshot::channel();
1007        executor
1008            .spawn(async move {
1009                io_handler.await.ok();
1010                io_ended_tx.send(()).unwrap();
1011            })
1012            .detach();
1013
1014        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
1015        executor
1016            .spawn(async move {
1017                incoming.next().await;
1018                messages_ended_tx.send(()).unwrap();
1019            })
1020            .detach();
1021
1022        client.disconnect(connection_id);
1023
1024        let _ = io_ended_rx.await;
1025        let _ = messages_ended_rx.await;
1026        assert!(
1027            server_conn
1028                .send(WebSocketMessage::Binary(vec![].into()))
1029                .await
1030                .is_err()
1031        );
1032    }
1033
1034    #[gpui::test(iterations = 50)]
1035    async fn test_io_error(cx: &mut TestAppContext) {
1036        let executor = cx.executor();
1037        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
1038
1039        let client = Peer::new(0);
1040        let (connection_id, io_handler, mut incoming) =
1041            client.add_test_connection(client_conn, executor.clone());
1042        executor.spawn(io_handler).detach();
1043        executor
1044            .spawn(async move { incoming.next().await })
1045            .detach();
1046
1047        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
1048        let _request = server_conn.rx.next().await.unwrap().unwrap();
1049
1050        drop(server_conn);
1051        assert_eq!(
1052            response.await.unwrap_err().to_string(),
1053            "connection was closed"
1054        );
1055    }
1056}