peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt, TryFutureExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use std::{fmt, sync::atomic::Ordering::SeqCst};
 15use std::{
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24use tracing::instrument;
 25
 26#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
 27pub struct ConnectionId {
 28    pub owner_id: u32,
 29    pub id: u32,
 30}
 31
 32impl Into<PeerId> for ConnectionId {
 33    fn into(self) -> PeerId {
 34        PeerId {
 35            owner_id: self.owner_id,
 36            id: self.id,
 37        }
 38    }
 39}
 40
 41impl From<PeerId> for ConnectionId {
 42    fn from(peer_id: PeerId) -> Self {
 43        Self {
 44            owner_id: peer_id.owner_id,
 45            id: peer_id.id,
 46        }
 47    }
 48}
 49
 50impl fmt::Display for ConnectionId {
 51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 52        write!(f, "{}/{}", self.owner_id, self.id)
 53    }
 54}
 55
/// Identifies a received request so that a response can be routed back to it.
///
/// `T` is the request's message type. It is tracked only at the type level
/// through `PhantomData`; no value of `T` is stored.
pub struct Receipt<T> {
    /// Connection the request was received on; responses are sent back over it.
    pub sender_id: ConnectionId,
    /// Id of the request message; responses carry it as their `responding_to`.
    pub message_id: u32,
    payload_type: PhantomData<T>,
}
 61
 62impl<T> Clone for Receipt<T> {
 63    fn clone(&self) -> Self {
 64        Self {
 65            sender_id: self.sender_id,
 66            message_id: self.message_id,
 67            payload_type: PhantomData,
 68        }
 69    }
 70}
 71
 72impl<T> Copy for Receipt<T> {}
 73
 74#[derive(Clone, Debug)]
 75pub struct TypedEnvelope<T> {
 76    pub sender_id: ConnectionId,
 77    pub original_sender_id: Option<PeerId>,
 78    pub message_id: u32,
 79    pub payload: T,
 80}
 81
 82impl<T> TypedEnvelope<T> {
 83    pub fn original_sender_id(&self) -> Result<PeerId> {
 84        self.original_sender_id
 85            .ok_or_else(|| anyhow!("missing original_sender_id"))
 86    }
 87}
 88
 89impl<T: RequestMessage> TypedEnvelope<T> {
 90    pub fn receipt(&self) -> Receipt<T> {
 91        Receipt {
 92            sender_id: self.sender_id,
 93            message_id: self.message_id,
 94            payload_type: PhantomData,
 95        }
 96    }
 97}
 98
/// Owns the set of registered connections and routes messages, requests and
/// responses over them.
pub struct Peer {
    // Copied into `ConnectionId::owner_id` for each newly added connection.
    epoch: AtomicU32,
    /// State of every registered connection, keyed by its id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of `ConnectionId::id`; set back to 0 by `reset`.
    next_connection_id: AtomicU32,
}
104
105#[derive(Clone, Serialize)]
106pub struct ConnectionState {
107    #[serde(skip)]
108    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
109    next_message_id: Arc<AtomicU32>,
110    #[allow(clippy::type_complexity)]
111    #[serde(skip)]
112    response_channels:
113        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
114}
115
// How long the connection may go without a successful write before a
// `proto::Message::Ping` is sent.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
// Upper bound on writing one message (or ping) to the socket, and on handing
// one incoming envelope to the incoming channel.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// The connection's I/O future fails if no message is received for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
119
120impl Peer {
121    pub fn new(epoch: u32) -> Arc<Self> {
122        Arc::new(Self {
123            epoch: AtomicU32::new(epoch),
124            connections: Default::default(),
125            next_connection_id: Default::default(),
126        })
127    }
128
129    pub fn epoch(&self) -> u32 {
130        self.epoch.load(SeqCst)
131    }
132
    /// Registers `connection` with this peer.
    ///
    /// Returns a tuple of:
    /// * the id assigned to the new connection,
    /// * a future that performs all socket I/O for the connection; nothing is
    ///   read or written unless the caller polls it (the tests spawn it),
    /// * a stream of incoming messages that are not responses to requests.
    ///
    /// `create_timer` is called with a duration to build the keepalive,
    /// write-timeout and receive-timeout timers; it is expected to return a
    /// future that completes once that duration has elapsed.
    ///
    /// The I/O future resolves to `Ok(())` when the outgoing channel is closed
    /// (every sender dropped, e.g. after `disconnect`) or when the incoming
    /// stream has been dropped. It resolves to an error when a read or write
    /// fails, a write or hand-off exceeds `WRITE_TIMEOUT`, or nothing is
    /// received for `RECEIVE_TIMEOUT`.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Cleanup that runs whenever this future ends: drop every pending
            // response sender (their requesters then see a closed channel) and
            // unregister the connection from the peer.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // The outer loop creates one read future per incoming message. The
            // inner loop keeps polling that same future while also servicing
            // outgoing messages and timers, and `break`s once the read completes.
            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    // `select_biased!` polls its arms in the order written:
                    // outgoing messages, then keepalive, then incoming reads,
                    // then the receive timeout.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            // Every sender for this connection has been dropped.
                            None => {
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            // Any received message, including a non-envelope
                            // one, restarts the receive timeout.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                // The incoming channel is bounded, so this send
                                // waits while the consumer is behind; give up
                                // after `WRITE_TIMEOUT`.
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // This read future is finished; let the outer loop
                            // start a new one.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Responses are delivered to the waiting requester and filtered out of
        // the stream; every other envelope is converted to a typed envelope
        // and yielded.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `as_mut()?` yields `None` (dropping the message) once the
                    // response channels have been taken by the I/O future.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Wait until the requester drops its end of the barrier
                        // (or the send above failed and dropped it) before
                        // letting the stream move on to later messages.
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
340
341    #[cfg(any(test, feature = "test-support"))]
342    pub fn add_test_connection(
343        self: &Arc<Self>,
344        connection: Connection,
345        executor: Arc<gpui::executor::Background>,
346    ) -> (
347        ConnectionId,
348        impl Future<Output = anyhow::Result<()>> + Send,
349        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
350    ) {
351        let executor = executor.clone();
352        self.add_connection(connection, move |duration| executor.timer(duration))
353    }
354
355    pub fn disconnect(&self, connection_id: ConnectionId) {
356        self.connections.write().remove(&connection_id);
357    }
358
359    pub fn reset(&self, epoch: u32) {
360        self.teardown();
361        self.next_connection_id.store(0, SeqCst);
362        self.epoch.store(epoch, SeqCst);
363    }
364
365    pub fn teardown(&self) {
366        self.connections.write().clear();
367    }
368
369    pub fn request<T: RequestMessage>(
370        &self,
371        receiver_id: ConnectionId,
372        request: T,
373    ) -> impl Future<Output = Result<T::Response>> {
374        self.request_internal(None, receiver_id, request)
375            .map_ok(|envelope| envelope.payload)
376    }
377
378    pub fn request_envelope<T: RequestMessage>(
379        &self,
380        receiver_id: ConnectionId,
381        request: T,
382    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
383        self.request_internal(None, receiver_id, request)
384    }
385
386    pub fn forward_request<T: RequestMessage>(
387        &self,
388        sender_id: ConnectionId,
389        receiver_id: ConnectionId,
390        request: T,
391    ) -> impl Future<Output = Result<T::Response>> {
392        self.request_internal(Some(sender_id), receiver_id, request)
393            .map_ok(|envelope| envelope.payload)
394    }
395
396    pub fn request_internal<T: RequestMessage>(
397        &self,
398        original_sender_id: Option<ConnectionId>,
399        receiver_id: ConnectionId,
400        request: T,
401    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
402        let (tx, rx) = oneshot::channel();
403        let send = self.connection_state(receiver_id).and_then(|connection| {
404            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
405            connection
406                .response_channels
407                .lock()
408                .as_mut()
409                .ok_or_else(|| anyhow!("connection was closed"))?
410                .insert(message_id, tx);
411            connection
412                .outgoing_tx
413                .unbounded_send(proto::Message::Envelope(request.into_envelope(
414                    message_id,
415                    None,
416                    original_sender_id.map(Into::into),
417                )))
418                .map_err(|_| anyhow!("connection was closed"))?;
419            Ok(())
420        });
421        async move {
422            send?;
423            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
424
425            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
426                Err(anyhow!(
427                    "RPC request {} failed - {}",
428                    T::NAME,
429                    error.message
430                ))
431            } else {
432                Ok(TypedEnvelope {
433                    message_id: response.id,
434                    sender_id: receiver_id,
435                    original_sender_id: response.original_sender_id,
436                    payload: T::Response::from_envelope(response)
437                        .ok_or_else(|| anyhow!("received response of the wrong type"))?,
438                })
439            }
440        }
441    }
442
443    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
444        let connection = self.connection_state(receiver_id)?;
445        let message_id = connection
446            .next_message_id
447            .fetch_add(1, atomic::Ordering::SeqCst);
448        connection
449            .outgoing_tx
450            .unbounded_send(proto::Message::Envelope(
451                message.into_envelope(message_id, None, None),
452            ))?;
453        Ok(())
454    }
455
456    pub fn forward_send<T: EnvelopedMessage>(
457        &self,
458        sender_id: ConnectionId,
459        receiver_id: ConnectionId,
460        message: T,
461    ) -> Result<()> {
462        let connection = self.connection_state(receiver_id)?;
463        let message_id = connection
464            .next_message_id
465            .fetch_add(1, atomic::Ordering::SeqCst);
466        connection
467            .outgoing_tx
468            .unbounded_send(proto::Message::Envelope(message.into_envelope(
469                message_id,
470                None,
471                Some(sender_id.into()),
472            )))?;
473        Ok(())
474    }
475
476    pub fn respond<T: RequestMessage>(
477        &self,
478        receipt: Receipt<T>,
479        response: T::Response,
480    ) -> Result<()> {
481        let connection = self.connection_state(receipt.sender_id)?;
482        let message_id = connection
483            .next_message_id
484            .fetch_add(1, atomic::Ordering::SeqCst);
485        connection
486            .outgoing_tx
487            .unbounded_send(proto::Message::Envelope(response.into_envelope(
488                message_id,
489                Some(receipt.message_id),
490                None,
491            )))?;
492        Ok(())
493    }
494
495    pub fn respond_with_error<T: RequestMessage>(
496        &self,
497        receipt: Receipt<T>,
498        response: proto::Error,
499    ) -> Result<()> {
500        let connection = self.connection_state(receipt.sender_id)?;
501        let message_id = connection
502            .next_message_id
503            .fetch_add(1, atomic::Ordering::SeqCst);
504        connection
505            .outgoing_tx
506            .unbounded_send(proto::Message::Envelope(response.into_envelope(
507                message_id,
508                Some(receipt.message_id),
509                None,
510            )))?;
511        Ok(())
512    }
513
514    pub fn respond_with_unhandled_message(
515        &self,
516        envelope: Box<dyn AnyTypedEnvelope>,
517    ) -> Result<()> {
518        let connection = self.connection_state(envelope.sender_id())?;
519        let response = proto::Error {
520            message: format!("message {} was not handled", envelope.payload_type_name()),
521        };
522        let message_id = connection
523            .next_message_id
524            .fetch_add(1, atomic::Ordering::SeqCst);
525        connection
526            .outgoing_tx
527            .unbounded_send(proto::Message::Envelope(response.into_envelope(
528                message_id,
529                Some(envelope.message_id()),
530                None,
531            )))?;
532        Ok(())
533    }
534
535    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
536        let connections = self.connections.read();
537        let connection = connections
538            .get(&connection_id)
539            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
540        Ok(connection.clone())
541    }
542}
543
544impl Serialize for Peer {
545    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
546    where
547        S: serde::Serializer,
548    {
549        let mut state = serializer.serialize_struct("Peer", 2)?;
550        state.serialize_field("connections", &*self.connections.read())?;
551        state.end()
552    }
553}
554
555#[cfg(test)]
556mod tests {
557    use super::*;
558    use crate::TypedEnvelope;
559    use async_tungstenite::tungstenite::Message as WebSocketMessage;
560    use gpui::TestAppContext;
561
562    #[ctor::ctor]
563    fn init_logger() {
564        if std::env::var("RUST_LOG").is_ok() {
565            env_logger::init();
566        }
567    }
568
569    #[gpui::test(iterations = 50)]
570    async fn test_request_response(cx: &mut TestAppContext) {
571        let executor = cx.foreground();
572
573        // create 2 clients connected to 1 server
574        let server = Peer::new(0);
575        let client1 = Peer::new(0);
576        let client2 = Peer::new(0);
577
578        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
579            Connection::in_memory(cx.background());
580        let (client1_conn_id, io_task1, client1_incoming) =
581            client1.add_test_connection(client1_to_server_conn, cx.background());
582        let (_, io_task2, server_incoming1) =
583            server.add_test_connection(server_to_client_1_conn, cx.background());
584
585        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
586            Connection::in_memory(cx.background());
587        let (client2_conn_id, io_task3, client2_incoming) =
588            client2.add_test_connection(client2_to_server_conn, cx.background());
589        let (_, io_task4, server_incoming2) =
590            server.add_test_connection(server_to_client_2_conn, cx.background());
591
592        executor.spawn(io_task1).detach();
593        executor.spawn(io_task2).detach();
594        executor.spawn(io_task3).detach();
595        executor.spawn(io_task4).detach();
596        executor
597            .spawn(handle_messages(server_incoming1, server.clone()))
598            .detach();
599        executor
600            .spawn(handle_messages(client1_incoming, client1.clone()))
601            .detach();
602        executor
603            .spawn(handle_messages(server_incoming2, server.clone()))
604            .detach();
605        executor
606            .spawn(handle_messages(client2_incoming, client2.clone()))
607            .detach();
608
609        assert_eq!(
610            client1
611                .request(client1_conn_id, proto::Ping {},)
612                .await
613                .unwrap(),
614            proto::Ack {}
615        );
616
617        assert_eq!(
618            client2
619                .request(client2_conn_id, proto::Ping {},)
620                .await
621                .unwrap(),
622            proto::Ack {}
623        );
624
625        assert_eq!(
626            client1
627                .request(client1_conn_id, proto::Test { id: 1 },)
628                .await
629                .unwrap(),
630            proto::Test { id: 1 }
631        );
632
633        assert_eq!(
634            client2
635                .request(client2_conn_id, proto::Test { id: 2 })
636                .await
637                .unwrap(),
638            proto::Test { id: 2 }
639        );
640
641        client1.disconnect(client1_conn_id);
642        client2.disconnect(client1_conn_id);
643
644        async fn handle_messages(
645            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
646            peer: Arc<Peer>,
647        ) -> Result<()> {
648            while let Some(envelope) = messages.next().await {
649                let envelope = envelope.into_any();
650                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
651                    let receipt = envelope.receipt();
652                    peer.respond(receipt, proto::Ack {})?
653                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
654                {
655                    peer.respond(envelope.receipt(), envelope.payload.clone())?
656                } else {
657                    panic!("unknown message type");
658                }
659            }
660
661            Ok(())
662        }
663    }
664
    /// Checks that messages the server sends before a response are observed
    /// by the client before the request future completes.
    ///
    /// The server, on receiving a `Ping`, sends "message 1" and "message 2"
    /// and only then responds with `Ack`; the client must record the events
    /// in exactly that order.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the ping, send two plain messages, then respond.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes things.
        let events = Arc::new(Mutex::new(Vec::new()));

        // The request is queued here, when `request` is called; the spawned
        // task only awaits its response.
        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: record the two plain messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
762
763    #[gpui::test(iterations = 50)]
764    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
765        let executor = cx.foreground();
766        let server = Peer::new(0);
767        let client = Peer::new(0);
768
769        let (client_to_server_conn, server_to_client_conn, _kill) =
770            Connection::in_memory(cx.background());
771        let (client_to_server_conn_id, io_task1, mut client_incoming) =
772            client.add_test_connection(client_to_server_conn, cx.background());
773        let (server_to_client_conn_id, io_task2, mut server_incoming) =
774            server.add_test_connection(server_to_client_conn, cx.background());
775
776        executor.spawn(io_task1).detach();
777        executor.spawn(io_task2).detach();
778
779        executor
780            .spawn(async move {
781                let request1 = server_incoming
782                    .next()
783                    .await
784                    .unwrap()
785                    .into_any()
786                    .downcast::<TypedEnvelope<proto::Ping>>()
787                    .unwrap();
788                let request2 = server_incoming
789                    .next()
790                    .await
791                    .unwrap()
792                    .into_any()
793                    .downcast::<TypedEnvelope<proto::Ping>>()
794                    .unwrap();
795
796                server
797                    .send(
798                        server_to_client_conn_id,
799                        proto::Error {
800                            message: "message 1".to_string(),
801                        },
802                    )
803                    .unwrap();
804                server
805                    .send(
806                        server_to_client_conn_id,
807                        proto::Error {
808                            message: "message 2".to_string(),
809                        },
810                    )
811                    .unwrap();
812                server.respond(request1.receipt(), proto::Ack {}).unwrap();
813                server.respond(request2.receipt(), proto::Ack {}).unwrap();
814
815                // Prevent the connection from being dropped
816                server_incoming.next().await;
817            })
818            .detach();
819
820        let events = Arc::new(Mutex::new(Vec::new()));
821
822        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
823        let request1_task = executor.spawn(request1);
824        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
825        let request2_task = executor.spawn({
826            let events = events.clone();
827            async move {
828                request2.await.unwrap();
829                events.lock().push("response 2".to_string());
830            }
831        });
832
833        executor
834            .spawn({
835                let events = events.clone();
836                async move {
837                    let incoming1 = client_incoming
838                        .next()
839                        .await
840                        .unwrap()
841                        .into_any()
842                        .downcast::<TypedEnvelope<proto::Error>>()
843                        .unwrap();
844                    events.lock().push(incoming1.payload.message);
845                    let incoming2 = client_incoming
846                        .next()
847                        .await
848                        .unwrap()
849                        .into_any()
850                        .downcast::<TypedEnvelope<proto::Error>>()
851                        .unwrap();
852                    events.lock().push(incoming2.payload.message);
853
854                    // Prevent the connection from being dropped
855                    client_incoming.next().await;
856                }
857            })
858            .detach();
859
860        // Allow the request to make some progress before dropping it.
861        cx.background().simulate_random_delay().await;
862        drop(request1_task);
863
864        request2_task.await;
865        assert_eq!(
866            &*events.lock(),
867            &[
868                "message 1".to_string(),
869                "message 2".to_string(),
870                "response 2".to_string()
871            ]
872        );
873    }
874
875    #[gpui::test(iterations = 50)]
876    async fn test_disconnect(cx: &mut TestAppContext) {
877        let executor = cx.foreground();
878
879        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
880
881        let client = Peer::new(0);
882        let (connection_id, io_handler, mut incoming) =
883            client.add_test_connection(client_conn, cx.background());
884
885        let (io_ended_tx, io_ended_rx) = oneshot::channel();
886        executor
887            .spawn(async move {
888                io_handler.await.ok();
889                io_ended_tx.send(()).unwrap();
890            })
891            .detach();
892
893        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
894        executor
895            .spawn(async move {
896                incoming.next().await;
897                messages_ended_tx.send(()).unwrap();
898            })
899            .detach();
900
901        client.disconnect(connection_id);
902
903        let _ = io_ended_rx.await;
904        let _ = messages_ended_rx.await;
905        assert!(server_conn
906            .send(WebSocketMessage::Binary(vec![]))
907            .await
908            .is_err());
909    }
910
911    #[gpui::test(iterations = 50)]
912    async fn test_io_error(cx: &mut TestAppContext) {
913        let executor = cx.foreground();
914        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
915
916        let client = Peer::new(0);
917        let (connection_id, io_handler, mut incoming) =
918            client.add_test_connection(client_conn, cx.background());
919        executor.spawn(io_handler).detach();
920        executor
921            .spawn(async move { incoming.next().await })
922            .detach();
923
924        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
925        let _request = server_conn.rx.next().await.unwrap().unwrap();
926
927        drop(server_conn);
928        assert_eq!(
929            response.await.unwrap_err().to_string(),
930            "connection was closed"
931        );
932    }
933}