peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use smol_timeout::TimeoutExt;
 15use std::sync::atomic::Ordering::SeqCst;
 16use std::{
 17    fmt,
 18    future::Future,
 19    marker::PhantomData,
 20    sync::{
 21        atomic::{self, AtomicU32},
 22        Arc,
 23    },
 24    time::Duration,
 25};
 26use tracing::instrument;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize)]
 29pub struct ConnectionId(pub u32);
 30
 31impl fmt::Display for ConnectionId {
 32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 33        self.0.fmt(f)
 34    }
 35}
 36
 37#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 38pub struct PeerId(pub u32);
 39
 40impl fmt::Display for PeerId {
 41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 42        self.0.fmt(f)
 43    }
 44}
 45
 46pub struct Receipt<T> {
 47    pub sender_id: ConnectionId,
 48    pub message_id: u32,
 49    payload_type: PhantomData<T>,
 50}
 51
 52impl<T> Clone for Receipt<T> {
 53    fn clone(&self) -> Self {
 54        Self {
 55            sender_id: self.sender_id,
 56            message_id: self.message_id,
 57            payload_type: PhantomData,
 58        }
 59    }
 60}
 61
 62impl<T> Copy for Receipt<T> {}
 63
 64pub struct TypedEnvelope<T> {
 65    pub sender_id: ConnectionId,
 66    pub original_sender_id: Option<PeerId>,
 67    pub message_id: u32,
 68    pub payload: T,
 69}
 70
 71impl<T> TypedEnvelope<T> {
 72    pub fn original_sender_id(&self) -> Result<PeerId> {
 73        self.original_sender_id
 74            .ok_or_else(|| anyhow!("missing original_sender_id"))
 75    }
 76}
 77
 78impl<T: RequestMessage> TypedEnvelope<T> {
 79    pub fn receipt(&self) -> Receipt<T> {
 80        Receipt {
 81            sender_id: self.sender_id,
 82            message_id: self.message_id,
 83            payload_type: PhantomData,
 84        }
 85    }
 86}
 87
/// Owns a set of RPC connections, keyed by [`ConnectionId`], and provides the
/// operations for sending messages, issuing requests, and answering them.
///
/// Constructed behind an `Arc` by [`Peer::new`].
pub struct Peer {
    /// State for every registered connection. Entries are inserted by
    /// `add_connection` and removed by `disconnect`, `reset`, or when the
    /// connection's I/O future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter from which `add_connection` allocates fresh connection ids.
    next_connection_id: AtomicU32,
}
 92
/// Per-connection state shared between [`Peer`]'s methods and the
/// connection's I/O future. Only `next_message_id` is serialized.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages for the I/O future to write to the socket.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Counter used to assign ids to outgoing messages on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Requests awaiting a response, keyed by the request's message id. Each
    /// sender receives the response envelope plus a `oneshot::Sender<()>` that
    /// the requester drops once it has resumed. Set to `None` when the I/O
    /// future ends, so no new requests can be registered on this connection.
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
102
/// Idle period after which a ping is written so the connection isn't closed;
/// the timer restarts after every successful write.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Maximum time allowed for writing a single message (or ping) to the socket.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection fails if no message is received within this window. Also
/// bounds how long an incoming message may wait to be handed to the consumer.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
106
107impl Peer {
108    pub fn new() -> Arc<Self> {
109        Arc::new(Self {
110            connections: Default::default(),
111            next_connection_id: Default::default(),
112        })
113    }
114
115    #[instrument(skip_all)]
116    pub async fn add_connection<F, Fut, Out>(
117        self: &Arc<Self>,
118        connection: Connection,
119        create_timer: F,
120    ) -> (
121        ConnectionId,
122        impl Future<Output = anyhow::Result<()>> + Send,
123        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
124    )
125    where
126        F: Send + Fn(Duration) -> Fut,
127        Fut: Send + Future<Output = Out>,
128        Out: Send,
129    {
130        // For outgoing messages, use an unbounded channel so that application code
131        // can always send messages without yielding. For incoming messages, use a
132        // bounded channel so that other peers will receive backpressure if they send
133        // messages faster than this peer can process them.
134        #[cfg(any(test, feature = "test-support"))]
135        const INCOMING_BUFFER_SIZE: usize = 1;
136        #[cfg(not(any(test, feature = "test-support")))]
137        const INCOMING_BUFFER_SIZE: usize = 64;
138        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
139        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
140
141        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
142        let connection_state = ConnectionState {
143            outgoing_tx: outgoing_tx.clone(),
144            next_message_id: Default::default(),
145            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
146        };
147        let mut writer = MessageStream::new(connection.tx);
148        let mut reader = MessageStream::new(connection.rx);
149
150        let this = self.clone();
151        let response_channels = connection_state.response_channels.clone();
152        let handle_io = async move {
153            tracing::debug!(%connection_id, "handle io future: start");
154
155            let _end_connection = util::defer(|| {
156                response_channels.lock().take();
157                this.connections.write().remove(&connection_id);
158                tracing::debug!(%connection_id, "handle io future: end");
159            });
160
161            // Send messages on this frequency so the connection isn't closed.
162            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
163            futures::pin_mut!(keepalive_timer);
164
165            // Disconnect if we don't receive messages at least this frequently.
166            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
167            futures::pin_mut!(receive_timeout);
168
169            loop {
170                tracing::debug!(%connection_id, "outer loop iteration start");
171                let read_message = reader.read().fuse();
172                futures::pin_mut!(read_message);
173
174                loop {
175                    tracing::debug!(%connection_id, "inner loop iteration start");
176                    futures::select_biased! {
177                        outgoing = outgoing_rx.next().fuse() => match outgoing {
178                            Some(outgoing) => {
179                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
180                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
181                                    tracing::debug!(%connection_id, "outgoing rpc message: done writing");
182                                    result.context("failed to write RPC message")?;
183                                    tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
184                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
185                                } else {
186                                    tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
187                                    Err(anyhow!("timed out writing message"))?;
188                                }
189                            }
190                            None => {
191                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
192                                return Ok(())
193                            },
194                        },
195                        incoming = read_message => {
196                            let incoming = incoming.context("error reading rpc message from socket")?;
197                            tracing::debug!(%connection_id, "incoming rpc message: received");
198                            tracing::debug!(%connection_id, "receive timeout: resetting");
199                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
200                            if let proto::Message::Envelope(incoming) = incoming {
201                                tracing::debug!(%connection_id, "incoming rpc message: processing");
202                                match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
203                                    Some(Ok(_)) => {
204                                        tracing::debug!(%connection_id, "incoming rpc message: processed");
205                                    },
206                                    Some(Err(_)) => {
207                                        tracing::debug!(%connection_id, "incoming rpc message: channel closed");
208                                        return Ok(())
209                                    },
210                                    None => {
211                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
212                                        Err(anyhow!("timed out processing incoming message"))?
213                                    },
214                                }
215                            }
216                            break;
217                        },
218                        _ = keepalive_timer => {
219                            tracing::debug!(%connection_id, "keepalive interval: pinging");
220                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
221                                tracing::debug!(%connection_id, "keepalive interval: done pinging");
222                                result.context("failed to send keepalive")?;
223                                tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
224                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
225                            } else {
226                                tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
227                                Err(anyhow!("timed out sending keepalive"))?;
228                            }
229                        }
230                        _ = receive_timeout => {
231                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
232                            Err(anyhow!("delay between messages too long"))?
233                        }
234                    }
235                }
236            }
237        };
238
239        let response_channels = connection_state.response_channels.clone();
240        self.connections
241            .write()
242            .insert(connection_id, connection_state);
243
244        let incoming_rx = incoming_rx.filter_map(move |incoming| {
245            let response_channels = response_channels.clone();
246            async move {
247                let message_id = incoming.id;
248                tracing::debug!(?incoming, "incoming message future: start");
249                let _end = util::defer(move || {
250                    tracing::debug!(
251                        %connection_id,
252                        message_id,
253                        "incoming message future: end"
254                    );
255                });
256
257                if let Some(responding_to) = incoming.responding_to {
258                    tracing::debug!(
259                        %connection_id,
260                        message_id,
261                        responding_to,
262                        "incoming response: received"
263                    );
264                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
265                    if let Some(tx) = channel {
266                        let requester_resumed = oneshot::channel();
267                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
268                            tracing::debug!(
269                                %connection_id,
270                                message_id,
271                                responding_to = responding_to,
272                                ?error,
273                                "incoming response: request future dropped",
274                            );
275                        }
276
277                        tracing::debug!(
278                            %connection_id,
279                            message_id,
280                            responding_to,
281                            "incoming response: waiting to resume requester"
282                        );
283                        let _ = requester_resumed.1.await;
284                        tracing::debug!(
285                            %connection_id,
286                            message_id,
287                            responding_to,
288                            "incoming response: requester resumed"
289                        );
290                    } else {
291                        tracing::warn!(
292                            %connection_id,
293                            message_id,
294                            responding_to,
295                            "incoming response: unknown request"
296                        );
297                    }
298
299                    None
300                } else {
301                    tracing::debug!(
302                        %connection_id,
303                        message_id,
304                        "incoming message: received"
305                    );
306                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
307                        tracing::error!(
308                            %connection_id,
309                            message_id,
310                            "unable to construct a typed envelope"
311                        );
312                        None
313                    })
314                }
315            }
316        });
317        (connection_id, handle_io, incoming_rx.boxed())
318    }
319
320    #[cfg(any(test, feature = "test-support"))]
321    pub async fn add_test_connection(
322        self: &Arc<Self>,
323        connection: Connection,
324        executor: Arc<gpui::executor::Background>,
325    ) -> (
326        ConnectionId,
327        impl Future<Output = anyhow::Result<()>> + Send,
328        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
329    ) {
330        let executor = executor.clone();
331        self.add_connection(connection, move |duration| executor.timer(duration))
332            .await
333    }
334
335    pub fn disconnect(&self, connection_id: ConnectionId) {
336        self.connections.write().remove(&connection_id);
337    }
338
339    pub fn reset(&self) {
340        self.connections.write().clear();
341    }
342
343    pub fn request<T: RequestMessage>(
344        &self,
345        receiver_id: ConnectionId,
346        request: T,
347    ) -> impl Future<Output = Result<T::Response>> {
348        self.request_internal(None, receiver_id, request)
349    }
350
351    pub fn forward_request<T: RequestMessage>(
352        &self,
353        sender_id: ConnectionId,
354        receiver_id: ConnectionId,
355        request: T,
356    ) -> impl Future<Output = Result<T::Response>> {
357        self.request_internal(Some(sender_id), receiver_id, request)
358    }
359
360    pub fn request_internal<T: RequestMessage>(
361        &self,
362        original_sender_id: Option<ConnectionId>,
363        receiver_id: ConnectionId,
364        request: T,
365    ) -> impl Future<Output = Result<T::Response>> {
366        let (tx, rx) = oneshot::channel();
367        let send = self.connection_state(receiver_id).and_then(|connection| {
368            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
369            connection
370                .response_channels
371                .lock()
372                .as_mut()
373                .ok_or_else(|| anyhow!("connection was closed"))?
374                .insert(message_id, tx);
375            connection
376                .outgoing_tx
377                .unbounded_send(proto::Message::Envelope(request.into_envelope(
378                    message_id,
379                    None,
380                    original_sender_id.map(|id| id.0),
381                )))
382                .map_err(|_| anyhow!("connection was closed"))?;
383            Ok(())
384        });
385        async move {
386            send?;
387            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
388            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
389                Err(anyhow!("RPC request failed - {}", error.message))
390            } else {
391                T::Response::from_envelope(response)
392                    .ok_or_else(|| anyhow!("received response of the wrong type"))
393            }
394        }
395    }
396
397    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
398        let connection = self.connection_state(receiver_id)?;
399        let message_id = connection
400            .next_message_id
401            .fetch_add(1, atomic::Ordering::SeqCst);
402        connection
403            .outgoing_tx
404            .unbounded_send(proto::Message::Envelope(
405                message.into_envelope(message_id, None, None),
406            ))?;
407        Ok(())
408    }
409
410    pub fn forward_send<T: EnvelopedMessage>(
411        &self,
412        sender_id: ConnectionId,
413        receiver_id: ConnectionId,
414        message: T,
415    ) -> Result<()> {
416        let connection = self.connection_state(receiver_id)?;
417        let message_id = connection
418            .next_message_id
419            .fetch_add(1, atomic::Ordering::SeqCst);
420        connection
421            .outgoing_tx
422            .unbounded_send(proto::Message::Envelope(message.into_envelope(
423                message_id,
424                None,
425                Some(sender_id.0),
426            )))?;
427        Ok(())
428    }
429
430    pub fn respond<T: RequestMessage>(
431        &self,
432        receipt: Receipt<T>,
433        response: T::Response,
434    ) -> Result<()> {
435        let connection = self.connection_state(receipt.sender_id)?;
436        let message_id = connection
437            .next_message_id
438            .fetch_add(1, atomic::Ordering::SeqCst);
439        connection
440            .outgoing_tx
441            .unbounded_send(proto::Message::Envelope(response.into_envelope(
442                message_id,
443                Some(receipt.message_id),
444                None,
445            )))?;
446        Ok(())
447    }
448
449    pub fn respond_with_error<T: RequestMessage>(
450        &self,
451        receipt: Receipt<T>,
452        response: proto::Error,
453    ) -> Result<()> {
454        let connection = self.connection_state(receipt.sender_id)?;
455        let message_id = connection
456            .next_message_id
457            .fetch_add(1, atomic::Ordering::SeqCst);
458        connection
459            .outgoing_tx
460            .unbounded_send(proto::Message::Envelope(response.into_envelope(
461                message_id,
462                Some(receipt.message_id),
463                None,
464            )))?;
465        Ok(())
466    }
467
468    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
469        let connections = self.connections.read();
470        let connection = connections
471            .get(&connection_id)
472            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
473        Ok(connection.clone())
474    }
475}
476
477impl Serialize for Peer {
478    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
479    where
480        S: serde::Serializer,
481    {
482        let mut state = serializer.serialize_struct("Peer", 2)?;
483        state.serialize_field("connections", &*self.connections.read())?;
484        state.end()
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients talk to one server; each request gets the expected reply.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client must disconnect its own connection id.
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response are observed before the response resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping a pending request must not stall delivery of later messages
    /// or of other requests' responses.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` ends the I/O future and the incoming stream, and closes the socket.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote end goes away.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}