peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use smol_timeout::TimeoutExt;
 14use std::sync::atomic::Ordering::SeqCst;
 15use std::{
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23    time::Duration,
 24};
 25use tracing::instrument;
 26
 27#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 28pub struct ConnectionId(pub u32);
 29
 30impl fmt::Display for ConnectionId {
 31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 32        self.0.fmt(f)
 33    }
 34}
 35
 36#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 37pub struct PeerId(pub u32);
 38
 39impl fmt::Display for PeerId {
 40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 41        self.0.fmt(f)
 42    }
 43}
 44
 45pub struct Receipt<T> {
 46    pub sender_id: ConnectionId,
 47    pub message_id: u32,
 48    payload_type: PhantomData<T>,
 49}
 50
 51impl<T> Clone for Receipt<T> {
 52    fn clone(&self) -> Self {
 53        Self {
 54            sender_id: self.sender_id,
 55            message_id: self.message_id,
 56            payload_type: PhantomData,
 57        }
 58    }
 59}
 60
 61impl<T> Copy for Receipt<T> {}
 62
 63pub struct TypedEnvelope<T> {
 64    pub sender_id: ConnectionId,
 65    pub original_sender_id: Option<PeerId>,
 66    pub message_id: u32,
 67    pub payload: T,
 68}
 69
 70impl<T> TypedEnvelope<T> {
 71    pub fn original_sender_id(&self) -> Result<PeerId> {
 72        self.original_sender_id
 73            .ok_or_else(|| anyhow!("missing original_sender_id"))
 74    }
 75}
 76
 77impl<T: RequestMessage> TypedEnvelope<T> {
 78    pub fn receipt(&self) -> Receipt<T> {
 79        Receipt {
 80            sender_id: self.sender_id,
 81            message_id: self.message_id,
 82            payload_type: PhantomData,
 83        }
 84    }
 85}
 86
 87pub struct Peer {
 88    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 89    next_connection_id: AtomicU32,
 90}
 91
 92#[derive(Clone)]
 93pub struct ConnectionState {
 94    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
 95    next_message_id: Arc<AtomicU32>,
 96    response_channels:
 97        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
 98}
 99
/// How long a connection may go without an outgoing write before a ping is sent.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Maximum time allowed for writing a single message (or ping) to the socket.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Maximum silence tolerated from the remote side before the connection is
/// dropped; also bounds how long delivery into the incoming stream may block.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
103
104impl Peer {
105    pub fn new() -> Arc<Self> {
106        Arc::new(Self {
107            connections: Default::default(),
108            next_connection_id: Default::default(),
109        })
110    }
111
112    #[instrument(skip_all)]
113    pub async fn add_connection<F, Fut, Out>(
114        self: &Arc<Self>,
115        connection: Connection,
116        create_timer: F,
117    ) -> (
118        ConnectionId,
119        impl Future<Output = anyhow::Result<()>> + Send,
120        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
121    )
122    where
123        F: Send + Fn(Duration) -> Fut,
124        Fut: Send + Future<Output = Out>,
125        Out: Send,
126    {
127        // For outgoing messages, use an unbounded channel so that application code
128        // can always send messages without yielding. For incoming messages, use a
129        // bounded channel so that other peers will receive backpressure if they send
130        // messages faster than this peer can process them.
131        #[cfg(any(test, feature = "test-support"))]
132        const INCOMING_BUFFER_SIZE: usize = 1;
133        #[cfg(not(any(test, feature = "test-support")))]
134        const INCOMING_BUFFER_SIZE: usize = 64;
135        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
136        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
137
138        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
139        let connection_state = ConnectionState {
140            outgoing_tx: outgoing_tx.clone(),
141            next_message_id: Default::default(),
142            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
143        };
144        let mut writer = MessageStream::new(connection.tx);
145        let mut reader = MessageStream::new(connection.rx);
146
147        let this = self.clone();
148        let response_channels = connection_state.response_channels.clone();
149        let handle_io = async move {
150            tracing::debug!(%connection_id, "handle io future: start");
151
152            let _end_connection = util::defer(|| {
153                response_channels.lock().take();
154                this.connections.write().remove(&connection_id);
155                tracing::debug!(%connection_id, "handle io future: end");
156            });
157
158            // Send messages on this frequency so the connection isn't closed.
159            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
160            futures::pin_mut!(keepalive_timer);
161
162            // Disconnect if we don't receive messages at least this frequently.
163            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
164            futures::pin_mut!(receive_timeout);
165
166            loop {
167                tracing::debug!(%connection_id, "outer loop iteration start");
168                let read_message = reader.read().fuse();
169                futures::pin_mut!(read_message);
170
171                loop {
172                    tracing::debug!(%connection_id, "inner loop iteration start");
173                    futures::select_biased! {
174                        outgoing = outgoing_rx.next().fuse() => match outgoing {
175                            Some(outgoing) => {
176                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
177                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
178                                    tracing::debug!(%connection_id, "outgoing rpc message: done writing");
179                                    result.context("failed to write RPC message")?;
180                                    tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
181                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
182                                } else {
183                                    tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
184                                    Err(anyhow!("timed out writing message"))?;
185                                }
186                            }
187                            None => {
188                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
189                                return Ok(())
190                            },
191                        },
192                        incoming = read_message => {
193                            let incoming = incoming.context("error reading rpc message from socket")?;
194                            tracing::debug!(%connection_id, "incoming rpc message: received");
195                            tracing::debug!(%connection_id, "receive timeout: resetting");
196                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
197                            if let proto::Message::Envelope(incoming) = incoming {
198                                tracing::debug!(%connection_id, "incoming rpc message: processing");
199                                match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
200                                    Some(Ok(_)) => {
201                                        tracing::debug!(%connection_id, "incoming rpc message: processed");
202                                    },
203                                    Some(Err(_)) => {
204                                        tracing::debug!(%connection_id, "incoming rpc message: channel closed");
205                                        return Ok(())
206                                    },
207                                    None => {
208                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
209                                        Err(anyhow!("timed out processing incoming message"))?
210                                    },
211                                }
212                            }
213                            break;
214                        },
215                        _ = keepalive_timer => {
216                            tracing::debug!(%connection_id, "keepalive interval: pinging");
217                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
218                                tracing::debug!(%connection_id, "keepalive interval: done pinging");
219                                result.context("failed to send keepalive")?;
220                                tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
221                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
222                            } else {
223                                tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
224                                Err(anyhow!("timed out sending keepalive"))?;
225                            }
226                        }
227                        _ = receive_timeout => {
228                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
229                            Err(anyhow!("delay between messages too long"))?
230                        }
231                    }
232                }
233            }
234        };
235
236        let response_channels = connection_state.response_channels.clone();
237        self.connections
238            .write()
239            .insert(connection_id, connection_state);
240
241        let incoming_rx = incoming_rx.filter_map(move |incoming| {
242            let response_channels = response_channels.clone();
243            async move {
244                let message_id = incoming.id;
245                tracing::debug!(?incoming, "incoming message future: start");
246                let _end = util::defer(move || {
247                    tracing::debug!(
248                        %connection_id,
249                        message_id,
250                        "incoming message future: end"
251                    );
252                });
253
254                if let Some(responding_to) = incoming.responding_to {
255                    tracing::debug!(
256                        %connection_id,
257                        message_id,
258                        responding_to,
259                        "incoming response: received"
260                    );
261                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
262                    if let Some(tx) = channel {
263                        let requester_resumed = oneshot::channel();
264                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
265                            tracing::debug!(
266                                %connection_id,
267                                message_id,
268                                responding_to = responding_to,
269                                ?error,
270                                "incoming response: request future dropped",
271                            );
272                        }
273
274                        tracing::debug!(
275                            %connection_id,
276                            message_id,
277                            responding_to,
278                            "incoming response: waiting to resume requester"
279                        );
280                        let _ = requester_resumed.1.await;
281                        tracing::debug!(
282                            %connection_id,
283                            message_id,
284                            responding_to,
285                            "incoming response: requester resumed"
286                        );
287                    } else {
288                        tracing::warn!(
289                            %connection_id,
290                            message_id,
291                            responding_to,
292                            "incoming response: unknown request"
293                        );
294                    }
295
296                    None
297                } else {
298                    tracing::debug!(
299                        %connection_id,
300                        message_id,
301                        "incoming message: received"
302                    );
303                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
304                        tracing::error!(
305                            %connection_id,
306                            message_id,
307                            "unable to construct a typed envelope"
308                        );
309                        None
310                    })
311                }
312            }
313        });
314        (connection_id, handle_io, incoming_rx.boxed())
315    }
316
317    #[cfg(any(test, feature = "test-support"))]
318    pub async fn add_test_connection(
319        self: &Arc<Self>,
320        connection: Connection,
321        executor: Arc<gpui::executor::Background>,
322    ) -> (
323        ConnectionId,
324        impl Future<Output = anyhow::Result<()>> + Send,
325        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
326    ) {
327        let executor = executor.clone();
328        self.add_connection(connection, move |duration| executor.timer(duration))
329            .await
330    }
331
332    pub fn disconnect(&self, connection_id: ConnectionId) {
333        self.connections.write().remove(&connection_id);
334    }
335
336    pub fn reset(&self) {
337        self.connections.write().clear();
338    }
339
340    pub fn request<T: RequestMessage>(
341        &self,
342        receiver_id: ConnectionId,
343        request: T,
344    ) -> impl Future<Output = Result<T::Response>> {
345        self.request_internal(None, receiver_id, request)
346    }
347
348    pub fn forward_request<T: RequestMessage>(
349        &self,
350        sender_id: ConnectionId,
351        receiver_id: ConnectionId,
352        request: T,
353    ) -> impl Future<Output = Result<T::Response>> {
354        self.request_internal(Some(sender_id), receiver_id, request)
355    }
356
357    pub fn request_internal<T: RequestMessage>(
358        &self,
359        original_sender_id: Option<ConnectionId>,
360        receiver_id: ConnectionId,
361        request: T,
362    ) -> impl Future<Output = Result<T::Response>> {
363        let (tx, rx) = oneshot::channel();
364        let send = self.connection_state(receiver_id).and_then(|connection| {
365            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
366            connection
367                .response_channels
368                .lock()
369                .as_mut()
370                .ok_or_else(|| anyhow!("connection was closed"))?
371                .insert(message_id, tx);
372            connection
373                .outgoing_tx
374                .unbounded_send(proto::Message::Envelope(request.into_envelope(
375                    message_id,
376                    None,
377                    original_sender_id.map(|id| id.0),
378                )))
379                .map_err(|_| anyhow!("connection was closed"))?;
380            Ok(())
381        });
382        async move {
383            send?;
384            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
385            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
386                Err(anyhow!("RPC request failed - {}", error.message))
387            } else {
388                T::Response::from_envelope(response)
389                    .ok_or_else(|| anyhow!("received response of the wrong type"))
390            }
391        }
392    }
393
394    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
395        let connection = self.connection_state(receiver_id)?;
396        let message_id = connection
397            .next_message_id
398            .fetch_add(1, atomic::Ordering::SeqCst);
399        connection
400            .outgoing_tx
401            .unbounded_send(proto::Message::Envelope(
402                message.into_envelope(message_id, None, None),
403            ))?;
404        Ok(())
405    }
406
407    pub fn forward_send<T: EnvelopedMessage>(
408        &self,
409        sender_id: ConnectionId,
410        receiver_id: ConnectionId,
411        message: T,
412    ) -> Result<()> {
413        let connection = self.connection_state(receiver_id)?;
414        let message_id = connection
415            .next_message_id
416            .fetch_add(1, atomic::Ordering::SeqCst);
417        connection
418            .outgoing_tx
419            .unbounded_send(proto::Message::Envelope(message.into_envelope(
420                message_id,
421                None,
422                Some(sender_id.0),
423            )))?;
424        Ok(())
425    }
426
427    pub fn respond<T: RequestMessage>(
428        &self,
429        receipt: Receipt<T>,
430        response: T::Response,
431    ) -> Result<()> {
432        let connection = self.connection_state(receipt.sender_id)?;
433        let message_id = connection
434            .next_message_id
435            .fetch_add(1, atomic::Ordering::SeqCst);
436        connection
437            .outgoing_tx
438            .unbounded_send(proto::Message::Envelope(response.into_envelope(
439                message_id,
440                Some(receipt.message_id),
441                None,
442            )))?;
443        Ok(())
444    }
445
446    pub fn respond_with_error<T: RequestMessage>(
447        &self,
448        receipt: Receipt<T>,
449        response: proto::Error,
450    ) -> Result<()> {
451        let connection = self.connection_state(receipt.sender_id)?;
452        let message_id = connection
453            .next_message_id
454            .fetch_add(1, atomic::Ordering::SeqCst);
455        connection
456            .outgoing_tx
457            .unbounded_send(proto::Message::Envelope(response.into_envelope(
458                message_id,
459                Some(receipt.message_id),
460                None,
461            )))?;
462        Ok(())
463    }
464
465    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
466        let connections = self.connections.read();
467        let connection = connections
468            .get(&connection_id)
469            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
470        Ok(connection.clone())
471    }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients talk to one server; each request gets the matching response.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client disconnects its own connection id.
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response are observed before the requester resumes.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request does not stall the incoming stream or
    /// other requests on the same connection.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` ends the I/O future and the incoming stream and closes the socket.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the socket drops.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}