peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use std::{fmt, sync::atomic::Ordering::SeqCst};
 15use std::{
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24use tracing::instrument;
 25
/// Identifies a single connection registered with a [`Peer`].
///
/// Displayed as `owner_id/id` and convertible to and from [`PeerId`].
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// The epoch of the owning `Peer` at the time the connection was added.
    pub owner_id: u32,
    /// Value of the peer's connection counter when the connection was added.
    pub id: u32,
}
 31
 32impl Into<PeerId> for ConnectionId {
 33    fn into(self) -> PeerId {
 34        PeerId {
 35            owner_id: self.owner_id,
 36            id: self.id,
 37        }
 38    }
 39}
 40
 41impl From<PeerId> for ConnectionId {
 42    fn from(peer_id: PeerId) -> Self {
 43        Self {
 44            owner_id: peer_id.owner_id,
 45            id: peer_id.id,
 46        }
 47    }
 48}
 49
 50impl fmt::Display for ConnectionId {
 51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 52        write!(f, "{}/{}", self.owner_id, self.id)
 53    }
 54}
 55
/// Identifies a received request of type `T` so that a response can be sent for it.
pub struct Receipt<T> {
    /// Copied from the request envelope's `sender_id`; responses are queued on this connection.
    pub sender_id: ConnectionId,
    /// Id of the request message; responses reference it as the message they respond to.
    pub message_id: u32,
    /// Ties the receipt to the request type without storing a value of that type.
    payload_type: PhantomData<T>,
}
 61
 62impl<T> Clone for Receipt<T> {
 63    fn clone(&self) -> Self {
 64        Self {
 65            sender_id: self.sender_id,
 66            message_id: self.message_id,
 67            payload_type: PhantomData,
 68        }
 69    }
 70}
 71
 72impl<T> Copy for Receipt<T> {}
 73
/// A message payload of type `T` together with the ids that accompanied it.
pub struct TypedEnvelope<T> {
    /// Connection associated with the sender of this message.
    pub sender_id: ConnectionId,
    /// Id of the original sender, when one is attached to the message.
    pub original_sender_id: Option<PeerId>,
    /// Id of this message.
    pub message_id: u32,
    /// The message body.
    pub payload: T,
}
 80
 81impl<T> TypedEnvelope<T> {
 82    pub fn original_sender_id(&self) -> Result<PeerId> {
 83        self.original_sender_id
 84            .ok_or_else(|| anyhow!("missing original_sender_id"))
 85    }
 86}
 87
 88impl<T: RequestMessage> TypedEnvelope<T> {
 89    pub fn receipt(&self) -> Receipt<T> {
 90        Receipt {
 91            sender_id: self.sender_id,
 92            message_id: self.message_id,
 93            payload_type: PhantomData,
 94        }
 95    }
 96}
 97
/// Owns the state of a set of connections and sends messages over them.
pub struct Peer {
    /// Used as the `owner_id` of every `ConnectionId` this peer creates.
    epoch: AtomicU32,
    /// State of every registered connection, keyed by connection id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Source of the `id` component for the next connection that is added.
    next_connection_id: AtomicU32,
}
103
/// Per-connection state shared between a `Peer` and the connection's I/O future.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages waiting to be written to the connection.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Id that will be assigned to the next message sent on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Senders for in-flight requests, keyed by the request's message id.
    /// Each sender delivers the response envelope along with a second sender
    /// that the requester holds until it has handled the response. The map is
    /// replaced by `None` when the connection's I/O future finishes.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
114
/// How long the I/O loop waits after its last successful write before sending a `Ping`.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Time limit for a single write to the connection, and for handing one
/// incoming envelope to the bounded incoming channel.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Longest allowed gap between incoming messages before the I/O loop fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
118
119impl Peer {
120    pub fn new(epoch: u32) -> Arc<Self> {
121        Arc::new(Self {
122            epoch: AtomicU32::new(epoch),
123            connections: Default::default(),
124            next_connection_id: Default::default(),
125        })
126    }
127
128    pub fn epoch(&self) -> u32 {
129        self.epoch.load(SeqCst)
130    }
131
    /// Registers `connection` with this peer.
    ///
    /// `create_timer` is called with a [`Duration`] whenever the I/O loop needs
    /// a keepalive, write-timeout or receive-timeout future.
    ///
    /// Returns a tuple of:
    /// - the id assigned to the connection,
    /// - a future that drives all reads and writes for the connection; when it
    ///   finishes (or is dropped after having started) the connection is
    ///   removed from `self.connections` and its pending response channels
    ///   are dropped,
    /// - a stream of incoming messages that are not responses. Responses are
    ///   delivered to the matching `request` future instead, and the stream
    ///   does not advance until that requester has released the response.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::debug!(%connection_id, "handle io future: start");

            // Runs when this future ends for any reason: drops every pending
            // response sender and unregisters the connection.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::debug!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::debug!(%connection_id, "outer loop iteration start");
                // One read future is kept alive across inner-loop iterations;
                // it is only replaced (by the outer loop) after it completes.
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::debug!(%connection_id, "inner loop iteration start");
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            // Every sender for the outgoing channel has been dropped.
                            None => {
                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::debug!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::debug!(%connection_id, "incoming rpc message: received");
                            tracing::debug!(%connection_id, "receive timeout: resetting");
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            // Only envelopes are forwarded; any other message
                            // merely resets the receive timeout above.
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::debug!(%connection_id, "incoming rpc message: processing");
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // The read future has completed; leave the inner loop
                            // so the outer loop can start a new read.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::debug!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::debug!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `?` yields `None` (dropping the message) once the response
                    // channels have been taken by the I/O future's cleanup.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Completes once the sender half handed to the requester
                        // is used or dropped, so later messages are not yielded
                        // before the requester has received this response.
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::debug!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
339
340    #[cfg(any(test, feature = "test-support"))]
341    pub fn add_test_connection(
342        self: &Arc<Self>,
343        connection: Connection,
344        executor: Arc<gpui::executor::Background>,
345    ) -> (
346        ConnectionId,
347        impl Future<Output = anyhow::Result<()>> + Send,
348        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
349    ) {
350        let executor = executor.clone();
351        self.add_connection(connection, move |duration| executor.timer(duration))
352    }
353
354    pub fn disconnect(&self, connection_id: ConnectionId) {
355        self.connections.write().remove(&connection_id);
356    }
357
358    pub fn reset(&self, epoch: u32) {
359        self.teardown();
360        self.next_connection_id.store(0, SeqCst);
361        self.epoch.store(epoch, SeqCst);
362    }
363
364    pub fn teardown(&self) {
365        self.connections.write().clear();
366    }
367
368    pub fn request<T: RequestMessage>(
369        &self,
370        receiver_id: ConnectionId,
371        request: T,
372    ) -> impl Future<Output = Result<T::Response>> {
373        self.request_internal(None, receiver_id, request)
374    }
375
376    pub fn forward_request<T: RequestMessage>(
377        &self,
378        sender_id: ConnectionId,
379        receiver_id: ConnectionId,
380        request: T,
381    ) -> impl Future<Output = Result<T::Response>> {
382        self.request_internal(Some(sender_id), receiver_id, request)
383    }
384
385    pub fn request_internal<T: RequestMessage>(
386        &self,
387        original_sender_id: Option<ConnectionId>,
388        receiver_id: ConnectionId,
389        request: T,
390    ) -> impl Future<Output = Result<T::Response>> {
391        let (tx, rx) = oneshot::channel();
392        let send = self.connection_state(receiver_id).and_then(|connection| {
393            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
394            connection
395                .response_channels
396                .lock()
397                .as_mut()
398                .ok_or_else(|| anyhow!("connection was closed"))?
399                .insert(message_id, tx);
400            connection
401                .outgoing_tx
402                .unbounded_send(proto::Message::Envelope(request.into_envelope(
403                    message_id,
404                    None,
405                    original_sender_id.map(Into::into),
406                )))
407                .map_err(|_| anyhow!("connection was closed"))?;
408            Ok(())
409        });
410        async move {
411            send?;
412            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
413            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
414                Err(anyhow!(
415                    "RPC request {} failed - {}",
416                    T::NAME,
417                    error.message
418                ))
419            } else {
420                T::Response::from_envelope(response)
421                    .ok_or_else(|| anyhow!("received response of the wrong type"))
422            }
423        }
424    }
425
426    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
427        let connection = self.connection_state(receiver_id)?;
428        let message_id = connection
429            .next_message_id
430            .fetch_add(1, atomic::Ordering::SeqCst);
431        connection
432            .outgoing_tx
433            .unbounded_send(proto::Message::Envelope(
434                message.into_envelope(message_id, None, None),
435            ))?;
436        Ok(())
437    }
438
439    pub fn forward_send<T: EnvelopedMessage>(
440        &self,
441        sender_id: ConnectionId,
442        receiver_id: ConnectionId,
443        message: T,
444    ) -> Result<()> {
445        let connection = self.connection_state(receiver_id)?;
446        let message_id = connection
447            .next_message_id
448            .fetch_add(1, atomic::Ordering::SeqCst);
449        connection
450            .outgoing_tx
451            .unbounded_send(proto::Message::Envelope(message.into_envelope(
452                message_id,
453                None,
454                Some(sender_id.into()),
455            )))?;
456        Ok(())
457    }
458
459    pub fn respond<T: RequestMessage>(
460        &self,
461        receipt: Receipt<T>,
462        response: T::Response,
463    ) -> Result<()> {
464        let connection = self.connection_state(receipt.sender_id)?;
465        let message_id = connection
466            .next_message_id
467            .fetch_add(1, atomic::Ordering::SeqCst);
468        connection
469            .outgoing_tx
470            .unbounded_send(proto::Message::Envelope(response.into_envelope(
471                message_id,
472                Some(receipt.message_id),
473                None,
474            )))?;
475        Ok(())
476    }
477
478    pub fn respond_with_error<T: RequestMessage>(
479        &self,
480        receipt: Receipt<T>,
481        response: proto::Error,
482    ) -> Result<()> {
483        let connection = self.connection_state(receipt.sender_id)?;
484        let message_id = connection
485            .next_message_id
486            .fetch_add(1, atomic::Ordering::SeqCst);
487        connection
488            .outgoing_tx
489            .unbounded_send(proto::Message::Envelope(response.into_envelope(
490                message_id,
491                Some(receipt.message_id),
492                None,
493            )))?;
494        Ok(())
495    }
496
497    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
498        let connections = self.connections.read();
499        let connection = connections
500            .get(&connection_id)
501            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
502        Ok(connection.clone())
503    }
504}
505
506impl Serialize for Peer {
507    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
508    where
509        S: serde::Serializer,
510    {
511        let mut state = serializer.serialize_struct("Peer", 2)?;
512        state.serialize_field("connections", &*self.connections.read())?;
513        state.end()
514    }
515}
516
517#[cfg(test)]
518mod tests {
519    use super::*;
520    use crate::TypedEnvelope;
521    use async_tungstenite::tungstenite::Message as WebSocketMessage;
522    use gpui::TestAppContext;
523
524    #[ctor::ctor]
525    fn init_logger() {
526        if std::env::var("RUST_LOG").is_ok() {
527            env_logger::init();
528        }
529    }
530
531    #[gpui::test(iterations = 50)]
532    async fn test_request_response(cx: &mut TestAppContext) {
533        let executor = cx.foreground();
534
535        // create 2 clients connected to 1 server
536        let server = Peer::new(0);
537        let client1 = Peer::new(0);
538        let client2 = Peer::new(0);
539
540        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
541            Connection::in_memory(cx.background());
542        let (client1_conn_id, io_task1, client1_incoming) =
543            client1.add_test_connection(client1_to_server_conn, cx.background());
544        let (_, io_task2, server_incoming1) =
545            server.add_test_connection(server_to_client_1_conn, cx.background());
546
547        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
548            Connection::in_memory(cx.background());
549        let (client2_conn_id, io_task3, client2_incoming) =
550            client2.add_test_connection(client2_to_server_conn, cx.background());
551        let (_, io_task4, server_incoming2) =
552            server.add_test_connection(server_to_client_2_conn, cx.background());
553
554        executor.spawn(io_task1).detach();
555        executor.spawn(io_task2).detach();
556        executor.spawn(io_task3).detach();
557        executor.spawn(io_task4).detach();
558        executor
559            .spawn(handle_messages(server_incoming1, server.clone()))
560            .detach();
561        executor
562            .spawn(handle_messages(client1_incoming, client1.clone()))
563            .detach();
564        executor
565            .spawn(handle_messages(server_incoming2, server.clone()))
566            .detach();
567        executor
568            .spawn(handle_messages(client2_incoming, client2.clone()))
569            .detach();
570
571        assert_eq!(
572            client1
573                .request(client1_conn_id, proto::Ping {},)
574                .await
575                .unwrap(),
576            proto::Ack {}
577        );
578
579        assert_eq!(
580            client2
581                .request(client2_conn_id, proto::Ping {},)
582                .await
583                .unwrap(),
584            proto::Ack {}
585        );
586
587        assert_eq!(
588            client1
589                .request(client1_conn_id, proto::Test { id: 1 },)
590                .await
591                .unwrap(),
592            proto::Test { id: 1 }
593        );
594
595        assert_eq!(
596            client2
597                .request(client2_conn_id, proto::Test { id: 2 })
598                .await
599                .unwrap(),
600            proto::Test { id: 2 }
601        );
602
603        client1.disconnect(client1_conn_id);
604        client2.disconnect(client1_conn_id);
605
606        async fn handle_messages(
607            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
608            peer: Arc<Peer>,
609        ) -> Result<()> {
610            while let Some(envelope) = messages.next().await {
611                let envelope = envelope.into_any();
612                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
613                    let receipt = envelope.receipt();
614                    peer.respond(receipt, proto::Ack {})?
615                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
616                {
617                    peer.respond(envelope.receipt(), envelope.payload.clone())?
618                } else {
619                    panic!("unknown message type");
620                }
621            }
622
623            Ok(())
624        }
625    }
626
    // The server sends two plain messages and then the response to a pending
    // request. The client must observe them in that order: both messages
    // first, then the response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the request, send two messages, then respond.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes messages and the response.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: record the two plain messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
724
    // Two requests are in flight and the first one's task is dropped before it
    // completes. The second request must still receive its response, after the
    // two plain messages the server sent before responding.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for both requests, send two messages, then respond
        // to the requests in the order they were received.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes messages and the response.
        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: record the two plain messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
836
    // After `disconnect`, the I/O future and the incoming stream must both
    // finish, and the other end of the connection must fail to send.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());

        // Signals when the I/O future has finished, whether it succeeded or failed.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signals when the incoming stream yields its next item or ends.
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
872
    // A request that is pending when the other end of the connection is
    // dropped must fail with "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request has actually reached the server end.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
895}