// peer.rs

   1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
   2
   3use super::{
   4    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
   5    Connection,
   6};
   7use anyhow::{anyhow, Context, Result};
   8use collections::HashMap;
   9use futures::{
  10    channel::{mpsc, oneshot},
  11    stream::BoxStream,
  12    FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
  13};
  14use parking_lot::{Mutex, RwLock};
  15use serde::{ser::SerializeStruct, Serialize};
  16use std::{
  17    fmt, future,
  18    future::Future,
  19    marker::PhantomData,
  20    sync::atomic::Ordering::SeqCst,
  21    sync::{
  22        atomic::{self, AtomicU32},
  23        Arc,
  24    },
  25    time::Duration,
  26    time::Instant,
  27};
  28use tracing::instrument;
  29
/// Identifies a single connection held by a [`Peer`].
///
/// `owner_id` is the epoch of the peer that created the connection and `id` is
/// a per-peer counter, so ids stay unique even across peer resets.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    pub owner_id: u32,
    pub id: u32,
}
  35
  36impl Into<PeerId> for ConnectionId {
  37    fn into(self) -> PeerId {
  38        PeerId {
  39            owner_id: self.owner_id,
  40            id: self.id,
  41        }
  42    }
  43}
  44
  45impl From<PeerId> for ConnectionId {
  46    fn from(peer_id: PeerId) -> Self {
  47        Self {
  48            owner_id: peer_id.owner_id,
  49            id: peer_id.id,
  50        }
  51    }
  52}
  53
  54impl fmt::Display for ConnectionId {
  55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  56        write!(f, "{}/{}", self.owner_id, self.id)
  57    }
  58}
  59
/// A handle for replying to a received request of type `T`.
///
/// Records which connection the request came from and its message id so a
/// response can be routed back later; `payload_type` pins `T` without storing
/// a value of it.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    payload_type: PhantomData<T>,
}
  65
// Manual impls: deriving `Clone`/`Copy` would add a `T: Clone`/`T: Copy` bound,
// but a `Receipt<T>` is always copyable since `T` appears only as `PhantomData`.
impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Receipt<T> {}
  73
/// A decoded incoming message together with its routing metadata.
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
    // Connection the message arrived on.
    pub sender_id: ConnectionId,
    // Set when the message was forwarded on behalf of another peer.
    pub original_sender_id: Option<PeerId>,
    pub message_id: u32,
    pub payload: T,
    // Wall-clock instant at which the message was read off the socket.
    pub received_at: Instant,
}
  82
  83impl<T> TypedEnvelope<T> {
  84    pub fn original_sender_id(&self) -> Result<PeerId> {
  85        self.original_sender_id
  86            .ok_or_else(|| anyhow!("missing original_sender_id"))
  87    }
  88}
  89
impl<T: RequestMessage> TypedEnvelope<T> {
    /// Creates a [`Receipt`] that can later be used to respond to this request.
    pub fn receipt(&self) -> Receipt<T> {
        Receipt {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}
  99
/// An RPC peer: owns the set of live connections and allocates connection ids.
pub struct Peer {
    // Epoch stamped into the `owner_id` half of every ConnectionId this peer creates.
    epoch: AtomicU32,
    // Live connections keyed by id; entries are removed on disconnect or when a
    // connection's I/O future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Counter for the `id` half of new ConnectionIds.
    next_connection_id: AtomicU32,
}
 105
/// Per-connection state shared between the `Peer` and the connection's I/O future.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    // Sender half of the connection's unbounded outgoing-message queue.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    // Source of unique ids for messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // One-shot response channels for unary requests, keyed by request message id.
    // The inner `Option` is taken (set to `None`) when the connection closes.
    // Each response carries the instant it was received plus a "barrier" sender
    // whose drop resumes the connection's message processing.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels: Arc<
        Mutex<
            Option<
                HashMap<
                    u32,
                    oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
                >,
            >,
        >,
    >,
    // Response channels for streaming requests, keyed the same way; entries stay
    // in the map until an EndStream payload arrives or the connection closes.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    stream_response_channels: Arc<
        Mutex<
            Option<
                HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
            >,
        >,
    >,
}
 133
/// How often to send a message (a ping if nothing else) so the other side's
/// receive timeout doesn't trip.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for a single write (or for delivering an incoming
/// message to the application) before the connection is failed.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Disconnect if no message (including keepalive pings) arrives in this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
 137
 138impl Peer {
 139    pub fn new(epoch: u32) -> Arc<Self> {
 140        Arc::new(Self {
 141            epoch: AtomicU32::new(epoch),
 142            connections: Default::default(),
 143            next_connection_id: Default::default(),
 144        })
 145    }
 146
    /// Returns the epoch this peer was constructed with (or last reset to),
    /// which becomes the `owner_id` of every connection id it allocates.
    pub fn epoch(&self) -> u32 {
        self.epoch.load(SeqCst)
    }
 150
    /// Registers a new connection with this peer.
    ///
    /// Returns the id assigned to the connection, a future that drives the
    /// connection's I/O (writes, keepalives, reads, and timeouts) until it
    /// fails or either side shuts down, and a stream of the incoming messages
    /// that are not responses to outstanding requests. The caller must spawn
    /// the I/O future and consume the stream.
    ///
    /// `create_timer` supplies all timers (keepalive interval plus write and
    /// receive timeouts), letting tests substitute a controllable clock.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 256;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        // Connection ids embed this peer's epoch so they stay unique across resets.
        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();

        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Runs when this future ends for any reason: tear down the response
            // maps so pending requesters observe a closed connection, notify all
            // open stream requests, and deregister the connection from the peer.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                if let Some(channels) = stream_response_channels.lock().take() {
                    for channel in channels.values() {
                        let _ = channel.unbounded_send((
                            Err(anyhow!("connection closed")),
                            oneshot::channel().0,
                        ));
                    }
                }
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                // One in-flight read at a time; the inner loop multiplexes writes
                // and timers against it and breaks once the read completes so the
                // outer loop can start a fresh one.
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                // Race the write against WRITE_TIMEOUT so a stalled
                                // socket fails the connection instead of hanging it.
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            None => {
                                // All outgoing senders dropped: the connection was
                                // removed from the peer, so shut down cleanly.
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            // Non-envelope messages (e.g. pings) only reset the
                            // receive timeout and are otherwise ignored here.
                            if let (proto::Message::Envelope(incoming), received_at) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                futures::select_biased! {
                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Demultiplex incoming envelopes: responses are routed to their waiting
        // request futures (and filtered out of the stream); everything else is
        // decoded into a typed envelope and yielded to the caller.
        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
            let response_channels = response_channels.clone();
            let stream_response_channels = stream_response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // Unary responses consume their channel; stream responses keep
                    // theirs registered until EndStream.
                    let response_channel =
                        response_channels.lock().as_mut()?.remove(&responding_to);
                    let stream_response_channel = stream_response_channels
                        .lock()
                        .as_ref()?
                        .get(&responding_to)
                        .cloned();

                    if let Some(tx) = response_channel {
                        // The barrier's drop (on the requester side) resumes this
                        // future, providing backpressure on response processing.
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else if let Some(tx) = stream_response_channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming stream response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: waiting to resume requester"
                        );
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: requester resumed"
                        );
                    } else {
                        // A response whose request is no longer registered; log and drop.
                        let message_type =
                            proto::build_typed_envelope(connection_id, received_at, incoming)
                                .map(|p| p.payload_type_name());
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            message_type,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
                        || {
                            tracing::error!(
                                %connection_id,
                                message_id,
                                "unable to construct a typed envelope"
                            );
                            None
                        },
                    )
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
 409
 410    #[cfg(any(test, feature = "test-support"))]
 411    pub fn add_test_connection(
 412        self: &Arc<Self>,
 413        connection: Connection,
 414        executor: gpui::BackgroundExecutor,
 415    ) -> (
 416        ConnectionId,
 417        impl Future<Output = anyhow::Result<()>> + Send,
 418        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 419    ) {
 420        let executor = executor.clone();
 421        self.add_connection(connection, move |duration| executor.timer(duration))
 422    }
 423
 424    pub fn disconnect(&self, connection_id: ConnectionId) {
 425        self.connections.write().remove(&connection_id);
 426    }
 427
    /// Test-only: rewinds the connection-id counter and adopts a new epoch, as
    /// if this peer had been restarted.
    #[cfg(any(test, feature = "test-support"))]
    pub fn reset(&self, epoch: u32) {
        self.next_connection_id.store(0, SeqCst);
        self.epoch.store(epoch, SeqCst);
    }
 433
    /// Drops all connection state; each connection's I/O future then finishes
    /// because its outgoing channel's senders are gone.
    pub fn teardown(&self) {
        self.connections.write().clear();
    }
 437
 438    /// Make a request and wait for a response.
 439    pub fn request<T: RequestMessage>(
 440        &self,
 441        receiver_id: ConnectionId,
 442        request: T,
 443    ) -> impl Future<Output = Result<T::Response>> {
 444        self.request_internal(None, receiver_id, request)
 445            .map_ok(|envelope| envelope.payload)
 446    }
 447
    /// Like [`Peer::request`], but resolves to the full `TypedEnvelope`
    /// (message id, sender, receive time) rather than just the payload.
    pub fn request_envelope<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
        self.request_internal(None, receiver_id, request)
    }
 455
    /// Like [`Peer::request`], but tags the envelope with `sender_id` as the
    /// original sender, for requests forwarded on behalf of another peer.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
            .map_ok(|envelope| envelope.payload)
    }
 465
    /// Shared implementation behind `request`, `request_envelope`, and
    /// `forward_request`: sends `request` to `receiver_id` (optionally marked
    /// as forwarded from `original_sender_id`) and decodes the reply envelope
    /// into `T::Response`.
    fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
        // The id passed here (0) is a placeholder; `request_dynamic` assigns
        // the real message id from the connection's counter.
        let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
        let response = self.request_dynamic(receiver_id, envelope, T::NAME);
        async move {
            let (response, received_at) = response.await?;
            Ok(TypedEnvelope {
                message_id: response.id,
                sender_id: receiver_id,
                original_sender_id: response.original_sender_id,
                payload: T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))?,
                received_at,
            })
        }
    }
 486
 487    /// Make a request and wait for a response.
 488    ///
 489    /// The caller must make sure to deserialize the response into the request's
 490    /// response type. This interface is only useful in trait objects, where
 491    /// generics can't be used. If you have a concrete type, use `request`.
 492    pub fn request_dynamic(
 493        &self,
 494        receiver_id: ConnectionId,
 495        mut envelope: proto::Envelope,
 496        type_name: &'static str,
 497    ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> {
 498        let (tx, rx) = oneshot::channel();
 499        let send = self.connection_state(receiver_id).and_then(|connection| {
 500            envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
 501            connection
 502                .response_channels
 503                .lock()
 504                .as_mut()
 505                .ok_or_else(|| anyhow!("connection was closed"))?
 506                .insert(envelope.id, tx);
 507            connection
 508                .outgoing_tx
 509                .unbounded_send(proto::Message::Envelope(envelope))
 510                .map_err(|_| anyhow!("connection was closed"))?;
 511            Ok(())
 512        });
 513        async move {
 514            send?;
 515            let (response, received_at, _barrier) =
 516                rx.await.map_err(|_| anyhow!("connection was closed"))?;
 517            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
 518                return Err(RpcError::from_proto(&error, type_name));
 519            }
 520            Ok((response, received_at))
 521        }
 522    }
 523
    /// Sends a request whose response arrives as a stream of messages.
    ///
    /// Resolves to a stream that yields decoded `T::Response` items until an
    /// `EndStream` payload arrives; error payloads and a closed connection
    /// surface as `Err` items on the stream.
    pub fn request_stream<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
        let (tx, rx) = mpsc::unbounded();
        // Register the stream channel and enqueue the request synchronously.
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            let stream_response_channels = connection.stream_response_channels.clone();
            stream_response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(
                    request.into_envelope(message_id, None, None),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok((message_id, stream_response_channels))
        });

        async move {
            let (message_id, stream_response_channels) = send?;
            // Hold only a weak reference so the returned stream doesn't keep the
            // connection's channel map alive after the connection closes.
            let stream_response_channels = Arc::downgrade(&stream_response_channels);

            Ok(rx.filter_map(move |(response, _barrier)| {
                let stream_response_channels = stream_response_channels.clone();
                future::ready(match response {
                    Ok(response) => {
                        if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                            Some(Err(anyhow!(
                                "RPC request {} failed - {}",
                                T::NAME,
                                error.message
                            )))
                        } else if let Some(proto::envelope::Payload::EndStream(_)) =
                            &response.payload
                        {
                            // Remove the transmitting end of the response channel to end the stream.
                            if let Some(channels) = stream_response_channels.upgrade() {
                                if let Some(channels) = channels.lock().as_mut() {
                                    channels.remove(&message_id);
                                }
                            }
                            None
                        } else {
                            Some(
                                T::Response::from_envelope(response)
                                    .ok_or_else(|| anyhow!("received response of the wrong type")),
                            )
                        }
                    }
                    Err(error) => Some(Err(error)),
                })
            }))
        }
    }
 583
 584    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
 585        let connection = self.connection_state(receiver_id)?;
 586        let message_id = connection
 587            .next_message_id
 588            .fetch_add(1, atomic::Ordering::SeqCst);
 589        connection
 590            .outgoing_tx
 591            .unbounded_send(proto::Message::Envelope(
 592                message.into_envelope(message_id, None, None),
 593            ))?;
 594        Ok(())
 595    }
 596
 597    pub fn forward_send<T: EnvelopedMessage>(
 598        &self,
 599        sender_id: ConnectionId,
 600        receiver_id: ConnectionId,
 601        message: T,
 602    ) -> Result<()> {
 603        let connection = self.connection_state(receiver_id)?;
 604        let message_id = connection
 605            .next_message_id
 606            .fetch_add(1, atomic::Ordering::SeqCst);
 607        connection
 608            .outgoing_tx
 609            .unbounded_send(proto::Message::Envelope(message.into_envelope(
 610                message_id,
 611                None,
 612                Some(sender_id.into()),
 613            )))?;
 614        Ok(())
 615    }
 616
 617    pub fn respond<T: RequestMessage>(
 618        &self,
 619        receipt: Receipt<T>,
 620        response: T::Response,
 621    ) -> Result<()> {
 622        let connection = self.connection_state(receipt.sender_id)?;
 623        let message_id = connection
 624            .next_message_id
 625            .fetch_add(1, atomic::Ordering::SeqCst);
 626        connection
 627            .outgoing_tx
 628            .unbounded_send(proto::Message::Envelope(response.into_envelope(
 629                message_id,
 630                Some(receipt.message_id),
 631                None,
 632            )))?;
 633        Ok(())
 634    }
 635
 636    pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
 637        let connection = self.connection_state(receipt.sender_id)?;
 638        let message_id = connection
 639            .next_message_id
 640            .fetch_add(1, atomic::Ordering::SeqCst);
 641
 642        let message = proto::EndStream {};
 643
 644        connection
 645            .outgoing_tx
 646            .unbounded_send(proto::Message::Envelope(message.into_envelope(
 647                message_id,
 648                Some(receipt.message_id),
 649                None,
 650            )))?;
 651        Ok(())
 652    }
 653
 654    pub fn respond_with_error<T: RequestMessage>(
 655        &self,
 656        receipt: Receipt<T>,
 657        response: proto::Error,
 658    ) -> Result<()> {
 659        let connection = self.connection_state(receipt.sender_id)?;
 660        let message_id = connection
 661            .next_message_id
 662            .fetch_add(1, atomic::Ordering::SeqCst);
 663        connection
 664            .outgoing_tx
 665            .unbounded_send(proto::Message::Envelope(response.into_envelope(
 666                message_id,
 667                Some(receipt.message_id),
 668                None,
 669            )))?;
 670        Ok(())
 671    }
 672
 673    pub fn respond_with_unhandled_message(
 674        &self,
 675        envelope: Box<dyn AnyTypedEnvelope>,
 676    ) -> Result<()> {
 677        let connection = self.connection_state(envelope.sender_id())?;
 678        let response = ErrorCode::Internal
 679            .message(format!(
 680                "message {} was not handled",
 681                envelope.payload_type_name()
 682            ))
 683            .to_proto();
 684        let message_id = connection
 685            .next_message_id
 686            .fetch_add(1, atomic::Ordering::SeqCst);
 687        connection
 688            .outgoing_tx
 689            .unbounded_send(proto::Message::Envelope(response.into_envelope(
 690                message_id,
 691                Some(envelope.message_id()),
 692                None,
 693            )))?;
 694        Ok(())
 695    }
 696
 697    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
 698        let connections = self.connections.read();
 699        let connection = connections
 700            .get(&connection_id)
 701            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
 702        Ok(connection.clone())
 703    }
 704}
 705
 706impl Serialize for Peer {
 707    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 708    where
 709        S: serde::Serializer,
 710    {
 711        let mut state = serializer.serialize_struct("Peer", 2)?;
 712        state.serialize_field("connections", &*self.connections.read())?;
 713        state.end()
 714    }
 715}
 716
 717#[cfg(test)]
 718mod tests {
 719    use super::*;
 720    use crate::TypedEnvelope;
 721    use async_tungstenite::tungstenite::Message as WebSocketMessage;
 722    use gpui::TestAppContext;
 723
    /// Initializes `env_logger` for tests, but only when `RUST_LOG` is set so
    /// that log output remains opt-in.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }
 729
 730    #[gpui::test(iterations = 50)]
 731    async fn test_request_response(cx: &mut TestAppContext) {
 732        init_logger();
 733
 734        let executor = cx.executor();
 735
 736        // create 2 clients connected to 1 server
 737        let server = Peer::new(0);
 738        let client1 = Peer::new(0);
 739        let client2 = Peer::new(0);
 740
 741        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
 742            Connection::in_memory(cx.executor());
 743        let (client1_conn_id, io_task1, client1_incoming) =
 744            client1.add_test_connection(client1_to_server_conn, cx.executor());
 745        let (_, io_task2, server_incoming1) =
 746            server.add_test_connection(server_to_client_1_conn, cx.executor());
 747
 748        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
 749            Connection::in_memory(cx.executor());
 750        let (client2_conn_id, io_task3, client2_incoming) =
 751            client2.add_test_connection(client2_to_server_conn, cx.executor());
 752        let (_, io_task4, server_incoming2) =
 753            server.add_test_connection(server_to_client_2_conn, cx.executor());
 754
 755        executor.spawn(io_task1).detach();
 756        executor.spawn(io_task2).detach();
 757        executor.spawn(io_task3).detach();
 758        executor.spawn(io_task4).detach();
 759        executor
 760            .spawn(handle_messages(server_incoming1, server.clone()))
 761            .detach();
 762        executor
 763            .spawn(handle_messages(client1_incoming, client1.clone()))
 764            .detach();
 765        executor
 766            .spawn(handle_messages(server_incoming2, server.clone()))
 767            .detach();
 768        executor
 769            .spawn(handle_messages(client2_incoming, client2.clone()))
 770            .detach();
 771
 772        assert_eq!(
 773            client1
 774                .request(client1_conn_id, proto::Ping {},)
 775                .await
 776                .unwrap(),
 777            proto::Ack {}
 778        );
 779
 780        assert_eq!(
 781            client2
 782                .request(client2_conn_id, proto::Ping {},)
 783                .await
 784                .unwrap(),
 785            proto::Ack {}
 786        );
 787
 788        assert_eq!(
 789            client1
 790                .request(client1_conn_id, proto::Test { id: 1 },)
 791                .await
 792                .unwrap(),
 793            proto::Test { id: 1 }
 794        );
 795
 796        assert_eq!(
 797            client2
 798                .request(client2_conn_id, proto::Test { id: 2 })
 799                .await
 800                .unwrap(),
 801            proto::Test { id: 2 }
 802        );
 803
 804        client1.disconnect(client1_conn_id);
 805        client2.disconnect(client1_conn_id);
 806
 807        async fn handle_messages(
 808            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
 809            peer: Arc<Peer>,
 810        ) -> Result<()> {
 811            while let Some(envelope) = messages.next().await {
 812                let envelope = envelope.into_any();
 813                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
 814                    let receipt = envelope.receipt();
 815                    peer.respond(receipt, proto::Ack {})?
 816                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
 817                {
 818                    peer.respond(envelope.receipt(), envelope.payload.clone())?
 819                } else {
 820                    panic!("unknown message type");
 821                }
 822            }
 823
 824            Ok(())
 825        }
 826    }
 827
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        // Verifies ordering: messages the server sends *before* responding to
        // a request must be observed by the client before the request's
        // response future resolves.
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server task: receive the Ping, send two Error messages, and only
        // then respond with Ack.
        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client task: drain the two Error messages that were sent ahead of
        // the response.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both out-of-band messages must have been observed before the
        // response resolved.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
 925
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        // Verifies that dropping one in-flight request does not disturb a
        // second concurrent request on the same connection, nor the ordering
        // of out-of-band messages relative to its response.
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server task: receive both Pings, send two Error messages, then
        // respond to both requests (including the one the client will drop).
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client task: drain the two Error messages sent ahead of the
        // responses.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // Request 1's response never appears (it was dropped); request 2 still
        // completes after both out-of-band messages.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
1037
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        // Verifies that `disconnect` tears everything down: the IO task and
        // the incoming-message stream both terminate, and the remote end of
        // the connection can no longer send.
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        // Signal when the IO task finishes.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signal when the incoming-message stream ends.
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        // The server's half of the in-memory connection must now fail to send.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
1073
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        // Verifies that dropping the remote end while a request is in flight
        // fails the pending request with a "connection was closed" error.
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        // Start a request and wait until the server side has received it
        // before severing the connection.
        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
1096}