peer.rs

   1use super::{
   2    Connection,
   3    message_stream::{Message, MessageStream},
   4    proto::{
   5        self, AnyTypedEnvelope, EnvelopedMessage, PeerId, Receipt, RequestMessage, TypedEnvelope,
   6    },
   7};
   8use anyhow::{Context as _, Result, anyhow};
   9use collections::HashMap;
  10use futures::{
  11    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
  12    channel::{mpsc, oneshot},
  13    stream::BoxStream,
  14};
  15use parking_lot::{Mutex, RwLock};
  16use proto::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
  17use serde::{Serialize, ser::SerializeStruct};
  18use std::{
  19    fmt, future,
  20    future::Future,
  21    sync::atomic::Ordering::SeqCst,
  22    sync::{
  23        Arc,
  24        atomic::{self, AtomicU32},
  25    },
  26    time::Duration,
  27    time::Instant,
  28};
  29use tracing::instrument;
  30
/// Identifies a single connection registered with a [`Peer`].
///
/// Converts field-for-field to and from `proto::PeerId` and is displayed as
/// `owner_id/id`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// Epoch of the peer at the time the id was assigned (see `Peer::add_connection`).
    pub owner_id: u32,
    /// Value of the peer's connection counter when the connection was added.
    pub id: u32,
}
  36
  37impl From<ConnectionId> for PeerId {
  38    fn from(id: ConnectionId) -> Self {
  39        PeerId {
  40            owner_id: id.owner_id,
  41            id: id.id,
  42        }
  43    }
  44}
  45
  46impl From<PeerId> for ConnectionId {
  47    fn from(peer_id: PeerId) -> Self {
  48        Self {
  49            owner_id: peer_id.owner_id,
  50            id: peer_id.id,
  51        }
  52    }
  53}
  54
  55impl fmt::Display for ConnectionId {
  56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  57        write!(f, "{}/{}", self.owner_id, self.id)
  58    }
  59}
  60
/// Keeps track of the connections of one RPC endpoint and assigns their ids.
pub struct Peer {
    /// Value used as the `owner_id` of every `ConnectionId` this peer creates.
    epoch: AtomicU32,
    /// State of each currently registered connection, keyed by its id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter that supplies the `id` component of the next `ConnectionId`.
    next_connection_id: AtomicU32,
}
  66
/// Shared, cloneable handle to the per-connection bookkeeping of a [`Peer`].
///
/// Only `next_message_id` is serialized; the channel fields are skipped.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages waiting to be written by the connection's I/O future.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<Message>,
    /// Counter that supplies the id of the next message sent on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Pending single-response requests, keyed by the id of the request message.
    ///
    /// Each sender is given the response envelope, the time it was received and
    /// a "barrier" sender that the requester holds while it handles the
    /// response. The map becomes `None` once the I/O future has ended.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels: Arc<
        Mutex<
            Option<
                HashMap<
                    u32,
                    oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
                >,
            >,
        >,
    >,
    /// Pending streaming requests, keyed by the id of the request message.
    ///
    /// Each sender is given either a response envelope or an error, plus the
    /// same kind of barrier sender as above. The map becomes `None` once the
    /// I/O future has ended.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    stream_response_channels: Arc<
        Mutex<
            Option<
                HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
            >,
        >,
    >,
}
  94
/// Time after the last successful write at which a ping is sent to keep the
/// connection alive.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for a single write to the connection, and for handing
/// an incoming message over to the incoming-message channel.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum time allowed between two received messages before the connection's
/// I/O future fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
  98
  99impl Peer {
 100    pub fn new(epoch: u32) -> Arc<Self> {
 101        Arc::new(Self {
 102            epoch: AtomicU32::new(epoch),
 103            connections: Default::default(),
 104            next_connection_id: Default::default(),
 105        })
 106    }
 107
 108    pub fn epoch(&self) -> u32 {
 109        self.epoch.load(SeqCst)
 110    }
 111
    /// Registers `connection` with this peer and returns what the caller needs
    /// to drive it.
    ///
    /// `create_timer` is called with a [`Duration`] whenever a keepalive,
    /// write-timeout or receive-timeout timer has to be (re)armed.
    ///
    /// Returns a tuple of:
    /// - the [`ConnectionId`] assigned to the connection, built from the peer's
    ///   current epoch and the next value of its connection counter;
    /// - a future that performs all reads and writes for the connection. It
    ///   resolves to `Ok(())` once the outgoing queue or the consumer of the
    ///   incoming stream has gone away, and to an error on read/write failures
    ///   or when one of the timeouts fires;
    /// - a stream of typed envelopes for incoming messages that are not
    ///   responses. Responses are routed to the pending request waiting for
    ///   them instead of being yielded.
    ///
    /// The connection is only serviced while the returned future is polled.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send + use<F, Fut, Out>,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 256;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        // The id combines the peer's epoch with a counter that is bumped for
        // every connection added.
        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        // Handles moved into the I/O future so that it can clean up when it ends.
        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();

        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Cleanup guard. NOTE(review): `util::defer` presumably runs the
            // closure when the guard is dropped, i.e. when this future finishes
            // or is dropped after it has started running -- confirm.
            let _end_connection = util::defer(|| {
                // Dropping the pending senders makes every in-flight request
                // observe a closed channel.
                response_channels.lock().take();
                // Streaming requests receive an explicit error item instead.
                if let Some(channels) = stream_response_channels.lock().take() {
                    for channel in channels.values() {
                        let _ = channel.unbounded_send((
                            Err(anyhow!("connection closed")),
                            oneshot::channel().0,
                        ));
                    }
                }
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                // One read future per outer iteration: it stays pending across
                // inner iterations while outgoing messages and keepalives are
                // handled, and a new read is only started after a message has
                // been received (see the `break` below).
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    // `select_biased!` gives earlier branches priority when
                    // several are ready: outgoing messages, then the keepalive,
                    // then incoming messages, then the receive timeout.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                // Race the write against WRITE_TIMEOUT.
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        // Any successful write counts as a keepalive.
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        anyhow::bail!("timed out writing message");
                                    }
                                }
                            }
                            None => {
                                // Every sender of the outgoing queue is gone.
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            // Race the ping against WRITE_TIMEOUT.
                            futures::select_biased! {
                                result = writer.write(Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    anyhow::bail!("timed out sending keepalive");
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            // Every successfully read message resets the receive
                            // timeout, not only envelopes.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            // Only envelopes are forwarded to the incoming stream.
                            if let (Message::Envelope(incoming), received_at) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                // The bounded channel may make this wait for the
                                // consumer; the wait is capped by WRITE_TIMEOUT.
                                futures::select_biased! {
                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        anyhow::bail!("timed out processing incoming message");
                                    }
                                }
                            }
                            // Leave the inner loop so that a fresh read is started.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            anyhow::bail!("delay between messages too long");
                        }
                    }
                }
            }
        };

        // Second set of handles, used by the incoming-message stream below.
        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();
        // Register the connection so that `connection_state` can find it.
        self.connections
            .write()
            .insert(connection_id, connection_state);

        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
            let response_channels = response_channels.clone();
            let stream_response_channels = stream_response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                // A message with `responding_to` set answers an earlier request
                // and is never yielded from the stream.
                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // A single-response entry is removed as soon as its response
                    // arrives; a stream entry is only cloned and stays registered.
                    // Either `?` drops the message if the map was already taken
                    // by the I/O future's cleanup.
                    let response_channel =
                        response_channels.lock().as_mut()?.remove(&responding_to);
                    let stream_response_channel = stream_response_channels
                        .lock()
                        .as_ref()?
                        .get(&responding_to)
                        .cloned();

                    if let Some(tx) = response_channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Resolves once the barrier sender handed to the
                        // requester is dropped, so the next incoming message is
                        // not processed before the requester has taken this one.
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else if let Some(tx) = stream_response_channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming stream response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: waiting to resume requester"
                        );
                        // Same barrier as for single responses.
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: requester resumed"
                        );
                    } else {
                        // No pending request with that id: the typed envelope is
                        // built only to name the payload type in the warning.
                        let message_type = proto::build_typed_envelope(
                            connection_id.into(),
                            received_at,
                            incoming,
                        )
                        .map(|p| p.payload_type_name());
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            message_type,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    // Messages that cannot be turned into a typed envelope are
                    // logged and skipped.
                    proto::build_typed_envelope(connection_id.into(), received_at, incoming)
                        .or_else(|| {
                            tracing::error!(
                                %connection_id,
                                message_id,
                                "unable to construct a typed envelope"
                            );
                            None
                        })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
 372
 373    #[cfg(any(test, feature = "test-support"))]
 374    pub fn add_test_connection(
 375        self: &Arc<Self>,
 376        connection: Connection,
 377        executor: gpui::BackgroundExecutor,
 378    ) -> (
 379        ConnectionId,
 380        impl Future<Output = anyhow::Result<()>> + Send + use<>,
 381        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 382    ) {
 383        let executor = executor.clone();
 384        self.add_connection(connection, move |duration| executor.timer(duration))
 385    }
 386
 387    pub fn disconnect(&self, connection_id: ConnectionId) {
 388        self.connections.write().remove(&connection_id);
 389    }
 390
 391    #[cfg(any(test, feature = "test-support"))]
 392    pub fn reset(&self, epoch: u32) {
 393        self.next_connection_id.store(0, SeqCst);
 394        self.epoch.store(epoch, SeqCst);
 395    }
 396
 397    pub fn teardown(&self) {
 398        self.connections.write().clear();
 399    }
 400
 401    /// Make a request and wait for a response.
 402    pub fn request<T: RequestMessage>(
 403        &self,
 404        receiver_id: ConnectionId,
 405        request: T,
 406    ) -> impl Future<Output = Result<T::Response>> + use<T> {
 407        self.request_internal(None, receiver_id, request)
 408            .map_ok(|envelope| envelope.payload)
 409    }
 410
 411    pub fn request_envelope<T: RequestMessage>(
 412        &self,
 413        receiver_id: ConnectionId,
 414        request: T,
 415    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
 416        self.request_internal(None, receiver_id, request)
 417    }
 418
 419    pub fn forward_request<T: RequestMessage>(
 420        &self,
 421        sender_id: ConnectionId,
 422        receiver_id: ConnectionId,
 423        request: T,
 424    ) -> impl Future<Output = Result<T::Response>> {
 425        let request_start_time = Instant::now();
 426        let payload_type = T::NAME;
 427        let elapsed_time = move || request_start_time.elapsed().as_millis();
 428        tracing::info!(payload_type, "start forwarding request");
 429        self.request_internal(Some(sender_id), receiver_id, request)
 430            .map_ok(|envelope| envelope.payload)
 431            .inspect_err(move |_| {
 432                tracing::error!(
 433                    waiting_for_host_ms = elapsed_time(),
 434                    payload_type,
 435                    "error forwarding request"
 436                )
 437            })
 438            .inspect_ok(move |_| {
 439                tracing::info!(
 440                    waiting_for_host_ms = elapsed_time(),
 441                    payload_type,
 442                    "finished forwarding request"
 443                )
 444            })
 445    }
 446
 447    fn request_internal<T: RequestMessage>(
 448        &self,
 449        original_sender_id: Option<ConnectionId>,
 450        receiver_id: ConnectionId,
 451        request: T,
 452    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
 453        let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
 454        let response = self.request_dynamic(receiver_id, envelope, T::NAME);
 455        async move {
 456            let (response, received_at) = response.await?;
 457            Ok(TypedEnvelope {
 458                message_id: response.id,
 459                sender_id: receiver_id.into(),
 460                original_sender_id: response.original_sender_id,
 461                payload: T::Response::from_envelope(response)
 462                    .context("received response of the wrong type")?,
 463                received_at,
 464            })
 465        }
 466    }
 467
 468    /// Make a request and wait for a response.
 469    ///
 470    /// The caller must make sure to deserialize the response into the request's
 471    /// response type. This interface is only useful in trait objects, where
 472    /// generics can't be used. If you have a concrete type, use `request`.
 473    pub fn request_dynamic(
 474        &self,
 475        receiver_id: ConnectionId,
 476        mut envelope: proto::Envelope,
 477        type_name: &'static str,
 478    ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> + use<> {
 479        let (tx, rx) = oneshot::channel();
 480        let send = self.connection_state(receiver_id).and_then(|connection| {
 481            envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
 482            connection
 483                .response_channels
 484                .lock()
 485                .as_mut()
 486                .context("connection was closed")?
 487                .insert(envelope.id, tx);
 488            connection
 489                .outgoing_tx
 490                .unbounded_send(Message::Envelope(envelope))
 491                .context("connection was closed")?;
 492            Ok(())
 493        });
 494        async move {
 495            send?;
 496            let (response, received_at, _barrier) = rx.await.context("connection was closed")?;
 497            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
 498                return Err(RpcError::from_proto(error, type_name));
 499            }
 500            Ok((response, received_at))
 501        }
 502    }
 503
 504    pub fn request_stream<T: RequestMessage>(
 505        &self,
 506        receiver_id: ConnectionId,
 507        request: T,
 508    ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
 509        let (tx, rx) = mpsc::unbounded();
 510        let send = self.connection_state(receiver_id).and_then(|connection| {
 511            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
 512            let stream_response_channels = connection.stream_response_channels.clone();
 513            stream_response_channels
 514                .lock()
 515                .as_mut()
 516                .context("connection was closed")?
 517                .insert(message_id, tx);
 518            connection
 519                .outgoing_tx
 520                .unbounded_send(Message::Envelope(
 521                    request.into_envelope(message_id, None, None),
 522                ))
 523                .context("connection was closed")?;
 524            Ok((message_id, stream_response_channels))
 525        });
 526
 527        async move {
 528            let (message_id, stream_response_channels) = send?;
 529            let stream_response_channels = Arc::downgrade(&stream_response_channels);
 530
 531            Ok(rx.filter_map(move |(response, _barrier)| {
 532                let stream_response_channels = stream_response_channels.clone();
 533                future::ready(match response {
 534                    Ok(response) => {
 535                        if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
 536                            Some(Err(RpcError::from_proto(error, T::NAME)))
 537                        } else if let Some(proto::envelope::Payload::EndStream(_)) =
 538                            &response.payload
 539                        {
 540                            // Remove the transmitting end of the response channel to end the stream.
 541                            if let Some(channels) = stream_response_channels.upgrade() {
 542                                if let Some(channels) = channels.lock().as_mut() {
 543                                    channels.remove(&message_id);
 544                                }
 545                            }
 546                            None
 547                        } else {
 548                            Some(
 549                                T::Response::from_envelope(response)
 550                                    .context("received response of the wrong type"),
 551                            )
 552                        }
 553                    }
 554                    Err(error) => Some(Err(error)),
 555                })
 556            }))
 557        }
 558    }
 559
 560    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
 561        let connection = self.connection_state(receiver_id)?;
 562        let message_id = connection
 563            .next_message_id
 564            .fetch_add(1, atomic::Ordering::SeqCst);
 565        connection.outgoing_tx.unbounded_send(Message::Envelope(
 566            message.into_envelope(message_id, None, None),
 567        ))?;
 568        Ok(())
 569    }
 570
 571    pub fn send_dynamic(&self, receiver_id: ConnectionId, message: proto::Envelope) -> Result<()> {
 572        let connection = self.connection_state(receiver_id)?;
 573        connection
 574            .outgoing_tx
 575            .unbounded_send(Message::Envelope(message))?;
 576        Ok(())
 577    }
 578
 579    pub fn forward_send<T: EnvelopedMessage>(
 580        &self,
 581        sender_id: ConnectionId,
 582        receiver_id: ConnectionId,
 583        message: T,
 584    ) -> Result<()> {
 585        let connection = self.connection_state(receiver_id)?;
 586        let message_id = connection
 587            .next_message_id
 588            .fetch_add(1, atomic::Ordering::SeqCst);
 589        connection
 590            .outgoing_tx
 591            .unbounded_send(Message::Envelope(message.into_envelope(
 592                message_id,
 593                None,
 594                Some(sender_id.into()),
 595            )))?;
 596        Ok(())
 597    }
 598
 599    pub fn respond<T: RequestMessage>(
 600        &self,
 601        receipt: Receipt<T>,
 602        response: T::Response,
 603    ) -> Result<()> {
 604        let connection = self.connection_state(receipt.sender_id.into())?;
 605        let message_id = connection
 606            .next_message_id
 607            .fetch_add(1, atomic::Ordering::SeqCst);
 608        connection
 609            .outgoing_tx
 610            .unbounded_send(Message::Envelope(response.into_envelope(
 611                message_id,
 612                Some(receipt.message_id),
 613                None,
 614            )))?;
 615        Ok(())
 616    }
 617
 618    pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
 619        let connection = self.connection_state(receipt.sender_id.into())?;
 620        let message_id = connection
 621            .next_message_id
 622            .fetch_add(1, atomic::Ordering::SeqCst);
 623
 624        let message = proto::EndStream {};
 625
 626        connection
 627            .outgoing_tx
 628            .unbounded_send(Message::Envelope(message.into_envelope(
 629                message_id,
 630                Some(receipt.message_id),
 631                None,
 632            )))?;
 633        Ok(())
 634    }
 635
 636    pub fn respond_with_error<T: RequestMessage>(
 637        &self,
 638        receipt: Receipt<T>,
 639        response: proto::Error,
 640    ) -> Result<()> {
 641        let connection = self.connection_state(receipt.sender_id.into())?;
 642        let message_id = connection
 643            .next_message_id
 644            .fetch_add(1, atomic::Ordering::SeqCst);
 645        connection
 646            .outgoing_tx
 647            .unbounded_send(Message::Envelope(response.into_envelope(
 648                message_id,
 649                Some(receipt.message_id),
 650                None,
 651            )))?;
 652        Ok(())
 653    }
 654
 655    pub fn respond_with_unhandled_message(
 656        &self,
 657        sender_id: ConnectionId,
 658        request_message_id: u32,
 659        message_type_name: &'static str,
 660    ) -> Result<()> {
 661        let connection = self.connection_state(sender_id)?;
 662        let response = ErrorCode::Internal
 663            .message(format!("message {} was not handled", message_type_name))
 664            .to_proto();
 665        let message_id = connection
 666            .next_message_id
 667            .fetch_add(1, atomic::Ordering::SeqCst);
 668        connection
 669            .outgoing_tx
 670            .unbounded_send(Message::Envelope(response.into_envelope(
 671                message_id,
 672                Some(request_message_id),
 673                None,
 674            )))?;
 675        Ok(())
 676    }
 677
 678    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
 679        let connections = self.connections.read();
 680        let connection = connections
 681            .get(&connection_id)
 682            .with_context(|| format!("no such connection: {connection_id}"))?;
 683        Ok(connection.clone())
 684    }
 685}
 686
 687impl Serialize for Peer {
 688    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 689    where
 690        S: serde::Serializer,
 691    {
 692        let mut state = serializer.serialize_struct("Peer", 2)?;
 693        state.serialize_field("connections", &*self.connections.read())?;
 694        state.end()
 695    }
 696}
 697
 698#[cfg(test)]
 699mod tests {
 700    use super::*;
 701    use async_tungstenite::tungstenite::Message as WebSocketMessage;
 702    use gpui::TestAppContext;
 703
 704    fn init_logger() {
 705        zlog::init_test();
 706    }
 707
 708    #[gpui::test(iterations = 50)]
 709    async fn test_request_response(cx: &mut TestAppContext) {
 710        init_logger();
 711
 712        let executor = cx.executor();
 713
 714        // create 2 clients connected to 1 server
 715        let server = Peer::new(0);
 716        let client1 = Peer::new(0);
 717        let client2 = Peer::new(0);
 718
 719        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
 720            Connection::in_memory(cx.executor());
 721        let (client1_conn_id, io_task1, client1_incoming) =
 722            client1.add_test_connection(client1_to_server_conn, cx.executor());
 723        let (_, io_task2, server_incoming1) =
 724            server.add_test_connection(server_to_client_1_conn, cx.executor());
 725
 726        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
 727            Connection::in_memory(cx.executor());
 728        let (client2_conn_id, io_task3, client2_incoming) =
 729            client2.add_test_connection(client2_to_server_conn, cx.executor());
 730        let (_, io_task4, server_incoming2) =
 731            server.add_test_connection(server_to_client_2_conn, cx.executor());
 732
 733        executor.spawn(io_task1).detach();
 734        executor.spawn(io_task2).detach();
 735        executor.spawn(io_task3).detach();
 736        executor.spawn(io_task4).detach();
 737        executor
 738            .spawn(handle_messages(server_incoming1, server.clone()))
 739            .detach();
 740        executor
 741            .spawn(handle_messages(client1_incoming, client1.clone()))
 742            .detach();
 743        executor
 744            .spawn(handle_messages(server_incoming2, server.clone()))
 745            .detach();
 746        executor
 747            .spawn(handle_messages(client2_incoming, client2.clone()))
 748            .detach();
 749
 750        assert_eq!(
 751            client1
 752                .request(client1_conn_id, proto::Ping {},)
 753                .await
 754                .unwrap(),
 755            proto::Ack {}
 756        );
 757
 758        assert_eq!(
 759            client2
 760                .request(client2_conn_id, proto::Ping {},)
 761                .await
 762                .unwrap(),
 763            proto::Ack {}
 764        );
 765
 766        assert_eq!(
 767            client1
 768                .request(client1_conn_id, proto::Test { id: 1 },)
 769                .await
 770                .unwrap(),
 771            proto::Test { id: 1 }
 772        );
 773
 774        assert_eq!(
 775            client2
 776                .request(client2_conn_id, proto::Test { id: 2 })
 777                .await
 778                .unwrap(),
 779            proto::Test { id: 2 }
 780        );
 781
 782        client1.disconnect(client1_conn_id);
 783        client2.disconnect(client1_conn_id);
 784
 785        async fn handle_messages(
 786            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 787            peer: Arc<Peer>,
 788        ) -> Result<()> {
 789            while let Some(envelope) = messages.next().await {
 790                let envelope = envelope.into_any();
 791                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
 792                    let receipt = envelope.receipt();
 793                    peer.respond(receipt, proto::Ack {})?
 794                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
 795                {
 796                    peer.respond(envelope.receipt(), envelope.payload.clone())?
 797                } else {
 798                    panic!("unknown message type");
 799                }
 800            }
 801
 802            Ok(())
 803        }
 804    }
 805
 806    #[gpui::test(iterations = 50)]
 807    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
 808        let executor = cx.executor();
 809        let server = Peer::new(0);
 810        let client = Peer::new(0);
 811
 812        let (client_to_server_conn, server_to_client_conn, _kill) =
 813            Connection::in_memory(executor.clone());
 814        let (client_to_server_conn_id, io_task1, mut client_incoming) =
 815            client.add_test_connection(client_to_server_conn, executor.clone());
 816
 817        let (server_to_client_conn_id, io_task2, mut server_incoming) =
 818            server.add_test_connection(server_to_client_conn, executor.clone());
 819
 820        executor.spawn(io_task1).detach();
 821        executor.spawn(io_task2).detach();
 822
 823        executor
 824            .spawn(async move {
 825                let future = server_incoming.next().await;
 826                let request = future
 827                    .unwrap()
 828                    .into_any()
 829                    .downcast::<TypedEnvelope<proto::Ping>>()
 830                    .unwrap();
 831
 832                server
 833                    .send(
 834                        server_to_client_conn_id,
 835                        ErrorCode::Internal
 836                            .message("message 1".to_string())
 837                            .to_proto(),
 838                    )
 839                    .unwrap();
 840                server
 841                    .send(
 842                        server_to_client_conn_id,
 843                        ErrorCode::Internal
 844                            .message("message 2".to_string())
 845                            .to_proto(),
 846                    )
 847                    .unwrap();
 848                server.respond(request.receipt(), proto::Ack {}).unwrap();
 849
 850                // Prevent the connection from being dropped
 851                server_incoming.next().await;
 852            })
 853            .detach();
 854
 855        let events = Arc::new(Mutex::new(Vec::new()));
 856
 857        let response = client.request(client_to_server_conn_id, proto::Ping {});
 858        let response_task = executor.spawn({
 859            let events = events.clone();
 860            async move {
 861                response.await.unwrap();
 862                events.lock().push("response".to_string());
 863            }
 864        });
 865
 866        executor
 867            .spawn({
 868                let events = events.clone();
 869                async move {
 870                    let incoming1 = client_incoming
 871                        .next()
 872                        .await
 873                        .unwrap()
 874                        .into_any()
 875                        .downcast::<TypedEnvelope<proto::Error>>()
 876                        .unwrap();
 877                    events.lock().push(incoming1.payload.message);
 878                    let incoming2 = client_incoming
 879                        .next()
 880                        .await
 881                        .unwrap()
 882                        .into_any()
 883                        .downcast::<TypedEnvelope<proto::Error>>()
 884                        .unwrap();
 885                    events.lock().push(incoming2.payload.message);
 886
 887                    // Prevent the connection from being dropped
 888                    client_incoming.next().await;
 889                }
 890            })
 891            .detach();
 892
 893        response_task.await;
 894        assert_eq!(
 895            &*events.lock(),
 896            &[
 897                "message 1".to_string(),
 898                "message 2".to_string(),
 899                "response".to_string()
 900            ]
 901        );
 902    }
 903
 904    #[gpui::test(iterations = 50)]
 905    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
 906        let executor = cx.executor();
 907        let server = Peer::new(0);
 908        let client = Peer::new(0);
 909
 910        let (client_to_server_conn, server_to_client_conn, _kill) =
 911            Connection::in_memory(cx.executor());
 912        let (client_to_server_conn_id, io_task1, mut client_incoming) =
 913            client.add_test_connection(client_to_server_conn, cx.executor());
 914        let (server_to_client_conn_id, io_task2, mut server_incoming) =
 915            server.add_test_connection(server_to_client_conn, cx.executor());
 916
 917        executor.spawn(io_task1).detach();
 918        executor.spawn(io_task2).detach();
 919
 920        executor
 921            .spawn(async move {
 922                let request1 = server_incoming
 923                    .next()
 924                    .await
 925                    .unwrap()
 926                    .into_any()
 927                    .downcast::<TypedEnvelope<proto::Ping>>()
 928                    .unwrap();
 929                let request2 = server_incoming
 930                    .next()
 931                    .await
 932                    .unwrap()
 933                    .into_any()
 934                    .downcast::<TypedEnvelope<proto::Ping>>()
 935                    .unwrap();
 936
 937                server
 938                    .send(
 939                        server_to_client_conn_id,
 940                        ErrorCode::Internal
 941                            .message("message 1".to_string())
 942                            .to_proto(),
 943                    )
 944                    .unwrap();
 945                server
 946                    .send(
 947                        server_to_client_conn_id,
 948                        ErrorCode::Internal
 949                            .message("message 2".to_string())
 950                            .to_proto(),
 951                    )
 952                    .unwrap();
 953                server.respond(request1.receipt(), proto::Ack {}).unwrap();
 954                server.respond(request2.receipt(), proto::Ack {}).unwrap();
 955
 956                // Prevent the connection from being dropped
 957                server_incoming.next().await;
 958            })
 959            .detach();
 960
 961        let events = Arc::new(Mutex::new(Vec::new()));
 962
 963        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
 964        let request1_task = executor.spawn(request1);
 965        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
 966        let request2_task = executor.spawn({
 967            let events = events.clone();
 968            async move {
 969                request2.await.unwrap();
 970                events.lock().push("response 2".to_string());
 971            }
 972        });
 973
 974        executor
 975            .spawn({
 976                let events = events.clone();
 977                async move {
 978                    let incoming1 = client_incoming
 979                        .next()
 980                        .await
 981                        .unwrap()
 982                        .into_any()
 983                        .downcast::<TypedEnvelope<proto::Error>>()
 984                        .unwrap();
 985                    events.lock().push(incoming1.payload.message);
 986                    let incoming2 = client_incoming
 987                        .next()
 988                        .await
 989                        .unwrap()
 990                        .into_any()
 991                        .downcast::<TypedEnvelope<proto::Error>>()
 992                        .unwrap();
 993                    events.lock().push(incoming2.payload.message);
 994
 995                    // Prevent the connection from being dropped
 996                    client_incoming.next().await;
 997                }
 998            })
 999            .detach();
1000
1001        // Allow the request to make some progress before dropping it.
1002        cx.executor().simulate_random_delay().await;
1003        drop(request1_task);
1004
1005        request2_task.await;
1006        assert_eq!(
1007            &*events.lock(),
1008            &[
1009                "message 1".to_string(),
1010                "message 2".to_string(),
1011                "response 2".to_string()
1012            ]
1013        );
1014    }
1015
1016    #[gpui::test(iterations = 50)]
1017    async fn test_disconnect(cx: &mut TestAppContext) {
1018        let executor = cx.executor();
1019
1020        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
1021
1022        let client = Peer::new(0);
1023        let (connection_id, io_handler, mut incoming) =
1024            client.add_test_connection(client_conn, executor.clone());
1025
1026        let (io_ended_tx, io_ended_rx) = oneshot::channel();
1027        executor
1028            .spawn(async move {
1029                io_handler.await.ok();
1030                io_ended_tx.send(()).unwrap();
1031            })
1032            .detach();
1033
1034        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
1035        executor
1036            .spawn(async move {
1037                incoming.next().await;
1038                messages_ended_tx.send(()).unwrap();
1039            })
1040            .detach();
1041
1042        client.disconnect(connection_id);
1043
1044        let _ = io_ended_rx.await;
1045        let _ = messages_ended_rx.await;
1046        assert!(
1047            server_conn
1048                .send(WebSocketMessage::Binary(vec![].into()))
1049                .await
1050                .is_err()
1051        );
1052    }
1053
1054    #[gpui::test(iterations = 50)]
1055    async fn test_io_error(cx: &mut TestAppContext) {
1056        let executor = cx.executor();
1057        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
1058
1059        let client = Peer::new(0);
1060        let (connection_id, io_handler, mut incoming) =
1061            client.add_test_connection(client_conn, executor.clone());
1062        executor.spawn(io_handler).detach();
1063        executor
1064            .spawn(async move { incoming.next().await })
1065            .detach();
1066
1067        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
1068        let _request = server_conn.rx.next().await.unwrap().unwrap();
1069
1070        drop(server_conn);
1071        assert_eq!(
1072            response.await.unwrap_err().to_string(),
1073            "connection was closed"
1074        );
1075    }
1076}