peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use std::sync::atomic::Ordering::SeqCst;
 15use std::{
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23    time::Duration,
 24};
 25use tracing::instrument;
 26
/// Identifier of a single connection registered with a [`Peer`].
///
/// `Peer::add_connection` allocates these from the peer's `next_connection_id` counter.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize)]
pub struct ConnectionId(pub u32);
 29
 30impl fmt::Display for ConnectionId {
 31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 32        self.0.fmt(f)
 33    }
 34}
 35
/// Identifier of a peer, as carried in `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
 38
 39impl fmt::Display for PeerId {
 40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 41        self.0.fmt(f)
 42    }
 43}
 44
/// Handle identifying a received request so that a response can be addressed to it.
///
/// Built by `TypedEnvelope::receipt` and consumed by `Peer::respond` and
/// `Peer::respond_with_error`.
pub struct Receipt<T> {
    /// Connection the request arrived on; responses are sent back over it.
    pub sender_id: ConnectionId,
    /// Id of the request message; `Peer::respond` passes it along with the response.
    pub message_id: u32,
    /// Ties the receipt to the request type `T` without storing a value of it.
    payload_type: PhantomData<T>,
}
 50
 51impl<T> Clone for Receipt<T> {
 52    fn clone(&self) -> Self {
 53        Self {
 54            sender_id: self.sender_id,
 55            message_id: self.message_id,
 56            payload_type: PhantomData,
 57        }
 58    }
 59}
 60
// Implemented by hand so that `Receipt<T>` is `Copy` without requiring `T: Copy`.
impl<T> Copy for Receipt<T> {}
 62
/// A received message whose payload has been decoded into the concrete type `T`.
pub struct TypedEnvelope<T> {
    /// Connection the message arrived on.
    pub sender_id: ConnectionId,
    /// Peer the message originated from, when the envelope carries one
    /// (see `Peer::forward_send` / `Peer::forward_request`).
    pub original_sender_id: Option<PeerId>,
    /// Id of the message within its connection.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
 69
 70impl<T> TypedEnvelope<T> {
 71    pub fn original_sender_id(&self) -> Result<PeerId> {
 72        self.original_sender_id
 73            .ok_or_else(|| anyhow!("missing original_sender_id"))
 74    }
 75}
 76
 77impl<T: RequestMessage> TypedEnvelope<T> {
 78    pub fn receipt(&self) -> Receipt<T> {
 79        Receipt {
 80            sender_id: self.sender_id,
 81            message_id: self.message_id,
 82            payload_type: PhantomData,
 83        }
 84    }
 85}
 86
/// One endpoint of the RPC protocol: tracks its open connections and sends
/// messages, requests and responses over them.
pub struct Peer {
    /// State of every registered connection, keyed by connection id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter from which ids for newly added connections are drawn.
    next_connection_id: AtomicU32,
}
 91
/// Per-connection bookkeeping shared between a [`Peer`] and the connection's I/O future.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages waiting to be written by the connection's I/O future.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Counter from which ids for outgoing messages on this connection are drawn.
    next_message_id: Arc<AtomicU32>,
    /// Senders for requests awaiting a response, keyed by the request's message id.
    /// Each one receives the response envelope together with a oneshot sender that
    /// the incoming-message stream waits on before continuing. The I/O future's
    /// cleanup takes the map out, leaving `None`, which later requests report as
    /// "connection was closed".
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
101
/// A `Ping` is written after this long without a successful write on the connection.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Upper bound on a single write to the connection; also bounds handing one
/// incoming message over to the incoming-message channel.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum time allowed between received messages before the I/O future fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
105
106impl Peer {
107    pub fn new() -> Arc<Self> {
108        Arc::new(Self {
109            connections: Default::default(),
110            next_connection_id: Default::default(),
111        })
112    }
113
    /// Registers `connection` with this peer and sets up its I/O plumbing.
    ///
    /// Returns a tuple of:
    /// - the id assigned to the new connection,
    /// - a future that performs all reading and writing for the connection; the
    ///   caller has to drive it (the tests spawn it on an executor),
    /// - a stream of incoming messages that are not responses to requests made
    ///   through this peer.
    ///
    /// `create_timer` must return a future that completes after the given duration;
    /// it backs the keepalive interval, the receive timeout and the write timeouts.
    ///
    /// The I/O future resolves to `Ok(())` when the outgoing or the incoming channel
    /// is closed, and to an error on a read/write failure or a timeout. Its deferred
    /// cleanup (`util::defer`, presumably run on drop; confirm) discards pending
    /// response channels and removes the connection from `connections`.
    #[instrument(skip_all)]
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::debug!(%connection_id, "handle io future: start");

            // Cleanup for this connection: drop every pending response sender (which
            // makes the corresponding requests fail) and forget the connection.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::debug!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // The outer loop creates one read future per incoming message. The inner
            // loop keeps polling that same read future while it services outgoing
            // messages and timers, and only `break`s once the read has completed.
            loop {
                tracing::debug!(%connection_id, "outer loop iteration start");
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::debug!(%connection_id, "inner loop iteration start");
                    // `select_biased!` polls its branches in the order written, so
                    // outgoing messages take priority over keepalives and reads.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            // Every sender is gone, so nothing more can be written.
                            None => {
                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::debug!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::debug!(%connection_id, "incoming rpc message: received");
                            tracing::debug!(%connection_id, "receive timeout: resetting");
                            // Any received message resets the receive timeout; only
                            // envelopes are forwarded to the incoming channel.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::debug!(%connection_id, "incoming rpc message: processing");
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // The read future has completed; go create a new one.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Responses are routed to the request that is waiting for them and filtered
        // out of the stream; everything else is converted to a typed envelope.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::debug!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message future: end"
                    );
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `?` yields `None` when the channel map has already been taken
                    // by the I/O future's cleanup.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Completes once the sender half handed to the requester has
                        // been used or dropped, so no later message is processed
                        // before the requester has picked up its response.
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message: received"
                    );
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
326
327    #[cfg(any(test, feature = "test-support"))]
328    pub async fn add_test_connection(
329        self: &Arc<Self>,
330        connection: Connection,
331        executor: Arc<gpui::executor::Background>,
332    ) -> (
333        ConnectionId,
334        impl Future<Output = anyhow::Result<()>> + Send,
335        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
336    ) {
337        let executor = executor.clone();
338        self.add_connection(connection, move |duration| executor.timer(duration))
339            .await
340    }
341
342    pub fn disconnect(&self, connection_id: ConnectionId) {
343        self.connections.write().remove(&connection_id);
344    }
345
346    pub fn reset(&self) {
347        self.connections.write().clear();
348    }
349
350    pub fn request<T: RequestMessage>(
351        &self,
352        receiver_id: ConnectionId,
353        request: T,
354    ) -> impl Future<Output = Result<T::Response>> {
355        self.request_internal(None, receiver_id, request)
356    }
357
358    pub fn forward_request<T: RequestMessage>(
359        &self,
360        sender_id: ConnectionId,
361        receiver_id: ConnectionId,
362        request: T,
363    ) -> impl Future<Output = Result<T::Response>> {
364        self.request_internal(Some(sender_id), receiver_id, request)
365    }
366
367    pub fn request_internal<T: RequestMessage>(
368        &self,
369        original_sender_id: Option<ConnectionId>,
370        receiver_id: ConnectionId,
371        request: T,
372    ) -> impl Future<Output = Result<T::Response>> {
373        let (tx, rx) = oneshot::channel();
374        let send = self.connection_state(receiver_id).and_then(|connection| {
375            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
376            connection
377                .response_channels
378                .lock()
379                .as_mut()
380                .ok_or_else(|| anyhow!("connection was closed"))?
381                .insert(message_id, tx);
382            connection
383                .outgoing_tx
384                .unbounded_send(proto::Message::Envelope(request.into_envelope(
385                    message_id,
386                    None,
387                    original_sender_id.map(|id| id.0),
388                )))
389                .map_err(|_| anyhow!("connection was closed"))?;
390            Ok(())
391        });
392        async move {
393            send?;
394            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
395            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
396                Err(anyhow!("RPC request failed - {}", error.message))
397            } else {
398                T::Response::from_envelope(response)
399                    .ok_or_else(|| anyhow!("received response of the wrong type"))
400            }
401        }
402    }
403
404    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
405        let connection = self.connection_state(receiver_id)?;
406        let message_id = connection
407            .next_message_id
408            .fetch_add(1, atomic::Ordering::SeqCst);
409        connection
410            .outgoing_tx
411            .unbounded_send(proto::Message::Envelope(
412                message.into_envelope(message_id, None, None),
413            ))?;
414        Ok(())
415    }
416
417    pub fn forward_send<T: EnvelopedMessage>(
418        &self,
419        sender_id: ConnectionId,
420        receiver_id: ConnectionId,
421        message: T,
422    ) -> Result<()> {
423        let connection = self.connection_state(receiver_id)?;
424        let message_id = connection
425            .next_message_id
426            .fetch_add(1, atomic::Ordering::SeqCst);
427        connection
428            .outgoing_tx
429            .unbounded_send(proto::Message::Envelope(message.into_envelope(
430                message_id,
431                None,
432                Some(sender_id.0),
433            )))?;
434        Ok(())
435    }
436
437    pub fn respond<T: RequestMessage>(
438        &self,
439        receipt: Receipt<T>,
440        response: T::Response,
441    ) -> Result<()> {
442        let connection = self.connection_state(receipt.sender_id)?;
443        let message_id = connection
444            .next_message_id
445            .fetch_add(1, atomic::Ordering::SeqCst);
446        connection
447            .outgoing_tx
448            .unbounded_send(proto::Message::Envelope(response.into_envelope(
449                message_id,
450                Some(receipt.message_id),
451                None,
452            )))?;
453        Ok(())
454    }
455
456    pub fn respond_with_error<T: RequestMessage>(
457        &self,
458        receipt: Receipt<T>,
459        response: proto::Error,
460    ) -> Result<()> {
461        let connection = self.connection_state(receipt.sender_id)?;
462        let message_id = connection
463            .next_message_id
464            .fetch_add(1, atomic::Ordering::SeqCst);
465        connection
466            .outgoing_tx
467            .unbounded_send(proto::Message::Envelope(response.into_envelope(
468                message_id,
469                Some(receipt.message_id),
470                None,
471            )))?;
472        Ok(())
473    }
474
475    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
476        let connections = self.connections.read();
477        let connection = connections
478            .get(&connection_id)
479            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
480        Ok(connection.clone())
481    }
482}
483
484impl Serialize for Peer {
485    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
486    where
487        S: serde::Serializer,
488    {
489        let mut state = serializer.serialize_struct("Peer", 2)?;
490        state.serialize_field("connections", &*self.connections.read())?;
491        state.end()
492    }
493}
494
495#[cfg(test)]
496mod tests {
497    use super::*;
498    use crate::TypedEnvelope;
499    use async_tungstenite::tungstenite::Message as WebSocketMessage;
500    use gpui::TestAppContext;
501
502    #[ctor::ctor]
503    fn init_logger() {
504        if std::env::var("RUST_LOG").is_ok() {
505            env_logger::init();
506        }
507    }
508
509    #[gpui::test(iterations = 50)]
510    async fn test_request_response(cx: &mut TestAppContext) {
511        let executor = cx.foreground();
512
513        // create 2 clients connected to 1 server
514        let server = Peer::new();
515        let client1 = Peer::new();
516        let client2 = Peer::new();
517
518        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
519            Connection::in_memory(cx.background());
520        let (client1_conn_id, io_task1, client1_incoming) = client1
521            .add_test_connection(client1_to_server_conn, cx.background())
522            .await;
523        let (_, io_task2, server_incoming1) = server
524            .add_test_connection(server_to_client_1_conn, cx.background())
525            .await;
526
527        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
528            Connection::in_memory(cx.background());
529        let (client2_conn_id, io_task3, client2_incoming) = client2
530            .add_test_connection(client2_to_server_conn, cx.background())
531            .await;
532        let (_, io_task4, server_incoming2) = server
533            .add_test_connection(server_to_client_2_conn, cx.background())
534            .await;
535
536        executor.spawn(io_task1).detach();
537        executor.spawn(io_task2).detach();
538        executor.spawn(io_task3).detach();
539        executor.spawn(io_task4).detach();
540        executor
541            .spawn(handle_messages(server_incoming1, server.clone()))
542            .detach();
543        executor
544            .spawn(handle_messages(client1_incoming, client1.clone()))
545            .detach();
546        executor
547            .spawn(handle_messages(server_incoming2, server.clone()))
548            .detach();
549        executor
550            .spawn(handle_messages(client2_incoming, client2.clone()))
551            .detach();
552
553        assert_eq!(
554            client1
555                .request(client1_conn_id, proto::Ping {},)
556                .await
557                .unwrap(),
558            proto::Ack {}
559        );
560
561        assert_eq!(
562            client2
563                .request(client2_conn_id, proto::Ping {},)
564                .await
565                .unwrap(),
566            proto::Ack {}
567        );
568
569        assert_eq!(
570            client1
571                .request(client1_conn_id, proto::Test { id: 1 },)
572                .await
573                .unwrap(),
574            proto::Test { id: 1 }
575        );
576
577        assert_eq!(
578            client2
579                .request(client2_conn_id, proto::Test { id: 2 })
580                .await
581                .unwrap(),
582            proto::Test { id: 2 }
583        );
584
585        client1.disconnect(client1_conn_id);
586        client2.disconnect(client1_conn_id);
587
588        async fn handle_messages(
589            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
590            peer: Arc<Peer>,
591        ) -> Result<()> {
592            while let Some(envelope) = messages.next().await {
593                let envelope = envelope.into_any();
594                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
595                    let receipt = envelope.receipt();
596                    peer.respond(receipt, proto::Ack {})?
597                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
598                {
599                    peer.respond(envelope.receipt(), envelope.payload.clone())?
600                } else {
601                    panic!("unknown message type");
602                }
603            }
604
605            Ok(())
606        }
607    }
608
    /// Checks that messages sent by the server before its response are observed
    /// by the client before the response itself.
    ///
    /// The server sends "message 1", "message 2" and then the `Ack`; the recorded
    /// event order on the client must match.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the ping, send two messages, then respond.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes things.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: record the two incoming messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
708
    /// Checks that dropping one in-flight request does not disturb the ordering
    /// of other incoming messages and responses.
    ///
    /// Two pings are sent; the first request's task is dropped after a random
    /// delay. The client must still observe "message 1", "message 2" and then the
    /// second response, in that order.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for both pings, send two messages, then answer both.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes things.
        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: record the two incoming messages as they arrive.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
822
    /// Checks that `Peer::disconnect` ends both the I/O future and the incoming
    /// stream, after which sending from the other end of the connection fails.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        // Signals when the I/O future has finished, whatever its result.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signals when the incoming stream has yielded (an item or its end).
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
859
    /// Checks that a pending request fails with "connection was closed" when the
    /// other end of the connection is dropped before it responds.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request has reached the server end before dropping it.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
883}