peer.rs

  1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
  2
  3use super::{
  4    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  5    Connection,
  6};
  7use anyhow::{anyhow, Context, Result};
  8use collections::HashMap;
  9use futures::{
 10    channel::{mpsc, oneshot},
 11    stream::BoxStream,
 12    FutureExt, SinkExt, StreamExt, TryFutureExt,
 13};
 14use parking_lot::{Mutex, RwLock};
 15use serde::{ser::SerializeStruct, Serialize};
 16use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
 17use std::{
 18    future::Future,
 19    marker::PhantomData,
 20    sync::{
 21        atomic::{self, AtomicU32},
 22        Arc,
 23    },
 24    time::Duration,
 25};
 26use tracing::instrument;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
 29pub struct ConnectionId {
 30    pub owner_id: u32,
 31    pub id: u32,
 32}
 33
 34impl Into<PeerId> for ConnectionId {
 35    fn into(self) -> PeerId {
 36        PeerId {
 37            owner_id: self.owner_id,
 38            id: self.id,
 39        }
 40    }
 41}
 42
 43impl From<PeerId> for ConnectionId {
 44    fn from(peer_id: PeerId) -> Self {
 45        Self {
 46            owner_id: peer_id.owner_id,
 47            id: peer_id.id,
 48        }
 49    }
 50}
 51
 52impl fmt::Display for ConnectionId {
 53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 54        write!(f, "{}/{}", self.owner_id, self.id)
 55    }
 56}
 57
 58pub struct Receipt<T> {
 59    pub sender_id: ConnectionId,
 60    pub message_id: u32,
 61    payload_type: PhantomData<T>,
 62}
 63
 64impl<T> Clone for Receipt<T> {
 65    fn clone(&self) -> Self {
 66        Self {
 67            sender_id: self.sender_id,
 68            message_id: self.message_id,
 69            payload_type: PhantomData,
 70        }
 71    }
 72}
 73
 74impl<T> Copy for Receipt<T> {}
 75
 76#[derive(Clone, Debug)]
 77pub struct TypedEnvelope<T> {
 78    pub sender_id: ConnectionId,
 79    pub original_sender_id: Option<PeerId>,
 80    pub message_id: u32,
 81    pub payload: T,
 82    pub received_at: Instant,
 83}
 84
 85impl<T> TypedEnvelope<T> {
 86    pub fn original_sender_id(&self) -> Result<PeerId> {
 87        self.original_sender_id
 88            .ok_or_else(|| anyhow!("missing original_sender_id"))
 89    }
 90}
 91
 92impl<T: RequestMessage> TypedEnvelope<T> {
 93    pub fn receipt(&self) -> Receipt<T> {
 94        Receipt {
 95            sender_id: self.sender_id,
 96            message_id: self.message_id,
 97            payload_type: PhantomData,
 98        }
 99    }
100}
101
102pub struct Peer {
103    epoch: AtomicU32,
104    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
105    next_connection_id: AtomicU32,
106}
107
108#[derive(Clone, Serialize)]
109pub struct ConnectionState {
110    #[serde(skip)]
111    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
112    next_message_id: Arc<AtomicU32>,
113    #[allow(clippy::type_complexity)]
114    #[serde(skip)]
115    response_channels: Arc<
116        Mutex<
117            Option<
118                HashMap<
119                    u32,
120                    oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
121                >,
122            >,
123        >,
124    >,
125}
126
127const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
128const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
129pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
130
131impl Peer {
132    pub fn new(epoch: u32) -> Arc<Self> {
133        Arc::new(Self {
134            epoch: AtomicU32::new(epoch),
135            connections: Default::default(),
136            next_connection_id: Default::default(),
137        })
138    }
139
140    pub fn epoch(&self) -> u32 {
141        self.epoch.load(SeqCst)
142    }
143
144    #[instrument(skip_all)]
145    pub fn add_connection<F, Fut, Out>(
146        self: &Arc<Self>,
147        connection: Connection,
148        create_timer: F,
149    ) -> (
150        ConnectionId,
151        impl Future<Output = anyhow::Result<()>> + Send,
152        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
153    )
154    where
155        F: Send + Fn(Duration) -> Fut,
156        Fut: Send + Future<Output = Out>,
157        Out: Send,
158    {
159        // For outgoing messages, use an unbounded channel so that application code
160        // can always send messages without yielding. For incoming messages, use a
161        // bounded channel so that other peers will receive backpressure if they send
162        // messages faster than this peer can process them.
163        #[cfg(any(test, feature = "test-support"))]
164        const INCOMING_BUFFER_SIZE: usize = 1;
165        #[cfg(not(any(test, feature = "test-support")))]
166        const INCOMING_BUFFER_SIZE: usize = 256;
167        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
168        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
169
170        let connection_id = ConnectionId {
171            owner_id: self.epoch.load(SeqCst),
172            id: self.next_connection_id.fetch_add(1, SeqCst),
173        };
174        let connection_state = ConnectionState {
175            outgoing_tx,
176            next_message_id: Default::default(),
177            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
178        };
179        let mut writer = MessageStream::new(connection.tx);
180        let mut reader = MessageStream::new(connection.rx);
181
182        let this = self.clone();
183        let response_channels = connection_state.response_channels.clone();
184        let handle_io = async move {
185            tracing::trace!(%connection_id, "handle io future: start");
186
187            let _end_connection = util::defer(|| {
188                response_channels.lock().take();
189                this.connections.write().remove(&connection_id);
190                tracing::trace!(%connection_id, "handle io future: end");
191            });
192
193            // Send messages on this frequency so the connection isn't closed.
194            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
195            futures::pin_mut!(keepalive_timer);
196
197            // Disconnect if we don't receive messages at least this frequently.
198            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
199            futures::pin_mut!(receive_timeout);
200
201            loop {
202                tracing::trace!(%connection_id, "outer loop iteration start");
203                let read_message = reader.read().fuse();
204                futures::pin_mut!(read_message);
205
206                loop {
207                    tracing::trace!(%connection_id, "inner loop iteration start");
208                    futures::select_biased! {
209                        outgoing = outgoing_rx.next().fuse() => match outgoing {
210                            Some(outgoing) => {
211                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
212                                futures::select_biased! {
213                                    result = writer.write(outgoing).fuse() => {
214                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
215                                        result.context("failed to write RPC message")?;
216                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
217                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
218                                    }
219                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
220                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
221                                        Err(anyhow!("timed out writing message"))?;
222                                    }
223                                }
224                            }
225                            None => {
226                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
227                                return Ok(())
228                            },
229                        },
230                        _ = keepalive_timer => {
231                            tracing::trace!(%connection_id, "keepalive interval: pinging");
232                            futures::select_biased! {
233                                result = writer.write(proto::Message::Ping).fuse() => {
234                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
235                                    result.context("failed to send keepalive")?;
236                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
237                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
238                                }
239                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
240                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
241                                    Err(anyhow!("timed out sending keepalive"))?;
242                                }
243                            }
244                        }
245                        incoming = read_message => {
246                            let incoming = incoming.context("error reading rpc message from socket")?;
247                            tracing::trace!(%connection_id, "incoming rpc message: received");
248                            tracing::trace!(%connection_id, "receive timeout: resetting");
249                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
250                            if let (proto::Message::Envelope(incoming), received_at) = incoming {
251                                tracing::trace!(%connection_id, "incoming rpc message: processing");
252                                futures::select_biased! {
253                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
254                                        Ok(_) => {
255                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
256                                        }
257                                        Err(_) => {
258                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
259                                            return Ok(())
260                                        }
261                                    },
262                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
263                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
264                                        Err(anyhow!("timed out processing incoming message"))?
265                                    }
266                                }
267                            }
268                            break;
269                        },
270                        _ = receive_timeout => {
271                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
272                            Err(anyhow!("delay between messages too long"))?
273                        }
274                    }
275                }
276            }
277        };
278
279        let response_channels = connection_state.response_channels.clone();
280        self.connections
281            .write()
282            .insert(connection_id, connection_state);
283
284        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
285            let response_channels = response_channels.clone();
286            async move {
287                let message_id = incoming.id;
288                tracing::trace!(?incoming, "incoming message future: start");
289                let _end = util::defer(move || {
290                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
291                });
292
293                if let Some(responding_to) = incoming.responding_to {
294                    tracing::trace!(
295                        %connection_id,
296                        message_id,
297                        responding_to,
298                        "incoming response: received"
299                    );
300                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
301                    if let Some(tx) = channel {
302                        let requester_resumed = oneshot::channel();
303                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
304                            tracing::trace!(
305                                %connection_id,
306                                message_id,
307                                responding_to = responding_to,
308                                ?error,
309                                "incoming response: request future dropped",
310                            );
311                        }
312
313                        tracing::trace!(
314                            %connection_id,
315                            message_id,
316                            responding_to,
317                            "incoming response: waiting to resume requester"
318                        );
319                        let _ = requester_resumed.1.await;
320                        tracing::trace!(
321                            %connection_id,
322                            message_id,
323                            responding_to,
324                            "incoming response: requester resumed"
325                        );
326                    } else {
327                        let message_type =
328                            proto::build_typed_envelope(connection_id, received_at, incoming)
329                                .map(|p| p.payload_type_name());
330                        tracing::warn!(
331                            %connection_id,
332                            message_id,
333                            responding_to,
334                            message_type,
335                            "incoming response: unknown request"
336                        );
337                    }
338
339                    None
340                } else {
341                    tracing::trace!(%connection_id, message_id, "incoming message: received");
342                    proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
343                        || {
344                            tracing::error!(
345                                %connection_id,
346                                message_id,
347                                "unable to construct a typed envelope"
348                            );
349                            None
350                        },
351                    )
352                }
353            }
354        });
355        (connection_id, handle_io, incoming_rx.boxed())
356    }
357
358    #[cfg(any(test, feature = "test-support"))]
359    pub fn add_test_connection(
360        self: &Arc<Self>,
361        connection: Connection,
362        executor: gpui::BackgroundExecutor,
363    ) -> (
364        ConnectionId,
365        impl Future<Output = anyhow::Result<()>> + Send,
366        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
367    ) {
368        let executor = executor.clone();
369        self.add_connection(connection, move |duration| executor.timer(duration))
370    }
371
372    pub fn disconnect(&self, connection_id: ConnectionId) {
373        self.connections.write().remove(&connection_id);
374    }
375
376    #[cfg(any(test, feature = "test-support"))]
377    pub fn reset(&self, epoch: u32) {
378        self.next_connection_id.store(0, SeqCst);
379        self.epoch.store(epoch, SeqCst);
380    }
381
382    pub fn teardown(&self) {
383        self.connections.write().clear();
384    }
385
386    pub fn request<T: RequestMessage>(
387        &self,
388        receiver_id: ConnectionId,
389        request: T,
390    ) -> impl Future<Output = Result<T::Response>> {
391        self.request_internal(None, receiver_id, request)
392            .map_ok(|envelope| envelope.payload)
393    }
394
395    pub fn request_envelope<T: RequestMessage>(
396        &self,
397        receiver_id: ConnectionId,
398        request: T,
399    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
400        self.request_internal(None, receiver_id, request)
401    }
402
403    pub fn forward_request<T: RequestMessage>(
404        &self,
405        sender_id: ConnectionId,
406        receiver_id: ConnectionId,
407        request: T,
408    ) -> impl Future<Output = Result<T::Response>> {
409        self.request_internal(Some(sender_id), receiver_id, request)
410            .map_ok(|envelope| envelope.payload)
411    }
412
413    pub fn request_internal<T: RequestMessage>(
414        &self,
415        original_sender_id: Option<ConnectionId>,
416        receiver_id: ConnectionId,
417        request: T,
418    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
419        let (tx, rx) = oneshot::channel();
420        let send = self.connection_state(receiver_id).and_then(|connection| {
421            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
422            connection
423                .response_channels
424                .lock()
425                .as_mut()
426                .ok_or_else(|| anyhow!("connection was closed"))?
427                .insert(message_id, tx);
428            connection
429                .outgoing_tx
430                .unbounded_send(proto::Message::Envelope(request.into_envelope(
431                    message_id,
432                    None,
433                    original_sender_id.map(Into::into),
434                )))
435                .map_err(|_| anyhow!("connection was closed"))?;
436            Ok(())
437        });
438        async move {
439            send?;
440            let (response, received_at, _barrier) =
441                rx.await.map_err(|_| anyhow!("connection was closed"))?;
442
443            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
444                Err(RpcError::from_proto(&error, T::NAME))
445            } else {
446                Ok(TypedEnvelope {
447                    message_id: response.id,
448                    sender_id: receiver_id,
449                    original_sender_id: response.original_sender_id,
450                    payload: T::Response::from_envelope(response)
451                        .ok_or_else(|| anyhow!("received response of the wrong type"))?,
452                    received_at,
453                })
454            }
455        }
456    }
457
458    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
459        let connection = self.connection_state(receiver_id)?;
460        let message_id = connection
461            .next_message_id
462            .fetch_add(1, atomic::Ordering::SeqCst);
463        connection
464            .outgoing_tx
465            .unbounded_send(proto::Message::Envelope(
466                message.into_envelope(message_id, None, None),
467            ))?;
468        Ok(())
469    }
470
471    pub fn forward_send<T: EnvelopedMessage>(
472        &self,
473        sender_id: ConnectionId,
474        receiver_id: ConnectionId,
475        message: T,
476    ) -> Result<()> {
477        let connection = self.connection_state(receiver_id)?;
478        let message_id = connection
479            .next_message_id
480            .fetch_add(1, atomic::Ordering::SeqCst);
481        connection
482            .outgoing_tx
483            .unbounded_send(proto::Message::Envelope(message.into_envelope(
484                message_id,
485                None,
486                Some(sender_id.into()),
487            )))?;
488        Ok(())
489    }
490
491    pub fn respond<T: RequestMessage>(
492        &self,
493        receipt: Receipt<T>,
494        response: T::Response,
495    ) -> Result<()> {
496        let connection = self.connection_state(receipt.sender_id)?;
497        let message_id = connection
498            .next_message_id
499            .fetch_add(1, atomic::Ordering::SeqCst);
500        connection
501            .outgoing_tx
502            .unbounded_send(proto::Message::Envelope(response.into_envelope(
503                message_id,
504                Some(receipt.message_id),
505                None,
506            )))?;
507        Ok(())
508    }
509
510    pub fn respond_with_error<T: RequestMessage>(
511        &self,
512        receipt: Receipt<T>,
513        response: proto::Error,
514    ) -> Result<()> {
515        let connection = self.connection_state(receipt.sender_id)?;
516        let message_id = connection
517            .next_message_id
518            .fetch_add(1, atomic::Ordering::SeqCst);
519        connection
520            .outgoing_tx
521            .unbounded_send(proto::Message::Envelope(response.into_envelope(
522                message_id,
523                Some(receipt.message_id),
524                None,
525            )))?;
526        Ok(())
527    }
528
529    pub fn respond_with_unhandled_message(
530        &self,
531        envelope: Box<dyn AnyTypedEnvelope>,
532    ) -> Result<()> {
533        let connection = self.connection_state(envelope.sender_id())?;
534        let response = ErrorCode::Internal
535            .message(format!(
536                "message {} was not handled",
537                envelope.payload_type_name()
538            ))
539            .to_proto();
540        let message_id = connection
541            .next_message_id
542            .fetch_add(1, atomic::Ordering::SeqCst);
543        connection
544            .outgoing_tx
545            .unbounded_send(proto::Message::Envelope(response.into_envelope(
546                message_id,
547                Some(envelope.message_id()),
548                None,
549            )))?;
550        Ok(())
551    }
552
553    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
554        let connections = self.connections.read();
555        let connection = connections
556            .get(&connection_id)
557            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
558        Ok(connection.clone())
559    }
560}
561
562impl Serialize for Peer {
563    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
564    where
565        S: serde::Serializer,
566    {
567        let mut state = serializer.serialize_struct("Peer", 2)?;
568        state.serialize_field("connections", &*self.connections.read())?;
569        state.end()
570    }
571}
572
573#[cfg(test)]
574mod tests {
575    use super::*;
576    use crate::TypedEnvelope;
577    use async_tungstenite::tungstenite::Message as WebSocketMessage;
578    use gpui::TestAppContext;
579
580    fn init_logger() {
581        if std::env::var("RUST_LOG").is_ok() {
582            env_logger::init();
583        }
584    }
585
586    #[gpui::test(iterations = 50)]
587    async fn test_request_response(cx: &mut TestAppContext) {
588        init_logger();
589
590        let executor = cx.executor();
591
592        // create 2 clients connected to 1 server
593        let server = Peer::new(0);
594        let client1 = Peer::new(0);
595        let client2 = Peer::new(0);
596
597        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
598            Connection::in_memory(cx.executor());
599        let (client1_conn_id, io_task1, client1_incoming) =
600            client1.add_test_connection(client1_to_server_conn, cx.executor());
601        let (_, io_task2, server_incoming1) =
602            server.add_test_connection(server_to_client_1_conn, cx.executor());
603
604        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
605            Connection::in_memory(cx.executor());
606        let (client2_conn_id, io_task3, client2_incoming) =
607            client2.add_test_connection(client2_to_server_conn, cx.executor());
608        let (_, io_task4, server_incoming2) =
609            server.add_test_connection(server_to_client_2_conn, cx.executor());
610
611        executor.spawn(io_task1).detach();
612        executor.spawn(io_task2).detach();
613        executor.spawn(io_task3).detach();
614        executor.spawn(io_task4).detach();
615        executor
616            .spawn(handle_messages(server_incoming1, server.clone()))
617            .detach();
618        executor
619            .spawn(handle_messages(client1_incoming, client1.clone()))
620            .detach();
621        executor
622            .spawn(handle_messages(server_incoming2, server.clone()))
623            .detach();
624        executor
625            .spawn(handle_messages(client2_incoming, client2.clone()))
626            .detach();
627
628        assert_eq!(
629            client1
630                .request(client1_conn_id, proto::Ping {},)
631                .await
632                .unwrap(),
633            proto::Ack {}
634        );
635
636        assert_eq!(
637            client2
638                .request(client2_conn_id, proto::Ping {},)
639                .await
640                .unwrap(),
641            proto::Ack {}
642        );
643
644        assert_eq!(
645            client1
646                .request(client1_conn_id, proto::Test { id: 1 },)
647                .await
648                .unwrap(),
649            proto::Test { id: 1 }
650        );
651
652        assert_eq!(
653            client2
654                .request(client2_conn_id, proto::Test { id: 2 })
655                .await
656                .unwrap(),
657            proto::Test { id: 2 }
658        );
659
660        client1.disconnect(client1_conn_id);
661        client2.disconnect(client1_conn_id);
662
663        async fn handle_messages(
664            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
665            peer: Arc<Peer>,
666        ) -> Result<()> {
667            while let Some(envelope) = messages.next().await {
668                let envelope = envelope.into_any();
669                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
670                    let receipt = envelope.receipt();
671                    peer.respond(receipt, proto::Ack {})?
672                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
673                {
674                    peer.respond(envelope.receipt(), envelope.payload.clone())?
675                } else {
676                    panic!("unknown message type");
677                }
678            }
679
680            Ok(())
681        }
682    }
683
684    #[gpui::test(iterations = 50)]
685    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
686        let executor = cx.executor();
687        let server = Peer::new(0);
688        let client = Peer::new(0);
689
690        let (client_to_server_conn, server_to_client_conn, _kill) =
691            Connection::in_memory(executor.clone());
692        let (client_to_server_conn_id, io_task1, mut client_incoming) =
693            client.add_test_connection(client_to_server_conn, executor.clone());
694
695        let (server_to_client_conn_id, io_task2, mut server_incoming) =
696            server.add_test_connection(server_to_client_conn, executor.clone());
697
698        executor.spawn(io_task1).detach();
699        executor.spawn(io_task2).detach();
700
701        executor
702            .spawn(async move {
703                let future = server_incoming.next().await;
704                let request = future
705                    .unwrap()
706                    .into_any()
707                    .downcast::<TypedEnvelope<proto::Ping>>()
708                    .unwrap();
709
710                server
711                    .send(
712                        server_to_client_conn_id,
713                        ErrorCode::Internal
714                            .message("message 1".to_string())
715                            .to_proto(),
716                    )
717                    .unwrap();
718                server
719                    .send(
720                        server_to_client_conn_id,
721                        ErrorCode::Internal
722                            .message("message 2".to_string())
723                            .to_proto(),
724                    )
725                    .unwrap();
726                server.respond(request.receipt(), proto::Ack {}).unwrap();
727
728                // Prevent the connection from being dropped
729                server_incoming.next().await;
730            })
731            .detach();
732
733        let events = Arc::new(Mutex::new(Vec::new()));
734
735        let response = client.request(client_to_server_conn_id, proto::Ping {});
736        let response_task = executor.spawn({
737            let events = events.clone();
738            async move {
739                response.await.unwrap();
740                events.lock().push("response".to_string());
741            }
742        });
743
744        executor
745            .spawn({
746                let events = events.clone();
747                async move {
748                    let incoming1 = client_incoming
749                        .next()
750                        .await
751                        .unwrap()
752                        .into_any()
753                        .downcast::<TypedEnvelope<proto::Error>>()
754                        .unwrap();
755                    events.lock().push(incoming1.payload.message);
756                    let incoming2 = client_incoming
757                        .next()
758                        .await
759                        .unwrap()
760                        .into_any()
761                        .downcast::<TypedEnvelope<proto::Error>>()
762                        .unwrap();
763                    events.lock().push(incoming2.payload.message);
764
765                    // Prevent the connection from being dropped
766                    client_incoming.next().await;
767                }
768            })
769            .detach();
770
771        response_task.await;
772        assert_eq!(
773            &*events.lock(),
774            &[
775                "message 1".to_string(),
776                "message 2".to_string(),
777                "response".to_string()
778            ]
779        );
780    }
781
782    #[gpui::test(iterations = 50)]
783    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
784        let executor = cx.executor();
785        let server = Peer::new(0);
786        let client = Peer::new(0);
787
788        let (client_to_server_conn, server_to_client_conn, _kill) =
789            Connection::in_memory(cx.executor());
790        let (client_to_server_conn_id, io_task1, mut client_incoming) =
791            client.add_test_connection(client_to_server_conn, cx.executor());
792        let (server_to_client_conn_id, io_task2, mut server_incoming) =
793            server.add_test_connection(server_to_client_conn, cx.executor());
794
795        executor.spawn(io_task1).detach();
796        executor.spawn(io_task2).detach();
797
798        executor
799            .spawn(async move {
800                let request1 = server_incoming
801                    .next()
802                    .await
803                    .unwrap()
804                    .into_any()
805                    .downcast::<TypedEnvelope<proto::Ping>>()
806                    .unwrap();
807                let request2 = server_incoming
808                    .next()
809                    .await
810                    .unwrap()
811                    .into_any()
812                    .downcast::<TypedEnvelope<proto::Ping>>()
813                    .unwrap();
814
815                server
816                    .send(
817                        server_to_client_conn_id,
818                        ErrorCode::Internal
819                            .message("message 1".to_string())
820                            .to_proto(),
821                    )
822                    .unwrap();
823                server
824                    .send(
825                        server_to_client_conn_id,
826                        ErrorCode::Internal
827                            .message("message 2".to_string())
828                            .to_proto(),
829                    )
830                    .unwrap();
831                server.respond(request1.receipt(), proto::Ack {}).unwrap();
832                server.respond(request2.receipt(), proto::Ack {}).unwrap();
833
834                // Prevent the connection from being dropped
835                server_incoming.next().await;
836            })
837            .detach();
838
839        let events = Arc::new(Mutex::new(Vec::new()));
840
841        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
842        let request1_task = executor.spawn(request1);
843        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
844        let request2_task = executor.spawn({
845            let events = events.clone();
846            async move {
847                request2.await.unwrap();
848                events.lock().push("response 2".to_string());
849            }
850        });
851
852        executor
853            .spawn({
854                let events = events.clone();
855                async move {
856                    let incoming1 = client_incoming
857                        .next()
858                        .await
859                        .unwrap()
860                        .into_any()
861                        .downcast::<TypedEnvelope<proto::Error>>()
862                        .unwrap();
863                    events.lock().push(incoming1.payload.message);
864                    let incoming2 = client_incoming
865                        .next()
866                        .await
867                        .unwrap()
868                        .into_any()
869                        .downcast::<TypedEnvelope<proto::Error>>()
870                        .unwrap();
871                    events.lock().push(incoming2.payload.message);
872
873                    // Prevent the connection from being dropped
874                    client_incoming.next().await;
875                }
876            })
877            .detach();
878
879        // Allow the request to make some progress before dropping it.
880        cx.executor().simulate_random_delay().await;
881        drop(request1_task);
882
883        request2_task.await;
884        assert_eq!(
885            &*events.lock(),
886            &[
887                "message 1".to_string(),
888                "message 2".to_string(),
889                "response 2".to_string()
890            ]
891        );
892    }
893
894    #[gpui::test(iterations = 50)]
895    async fn test_disconnect(cx: &mut TestAppContext) {
896        let executor = cx.executor();
897
898        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
899
900        let client = Peer::new(0);
901        let (connection_id, io_handler, mut incoming) =
902            client.add_test_connection(client_conn, executor.clone());
903
904        let (io_ended_tx, io_ended_rx) = oneshot::channel();
905        executor
906            .spawn(async move {
907                io_handler.await.ok();
908                io_ended_tx.send(()).unwrap();
909            })
910            .detach();
911
912        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
913        executor
914            .spawn(async move {
915                incoming.next().await;
916                messages_ended_tx.send(()).unwrap();
917            })
918            .detach();
919
920        client.disconnect(connection_id);
921
922        let _ = io_ended_rx.await;
923        let _ = messages_ended_rx.await;
924        assert!(server_conn
925            .send(WebSocketMessage::Binary(vec![]))
926            .await
927            .is_err());
928    }
929
930    #[gpui::test(iterations = 50)]
931    async fn test_io_error(cx: &mut TestAppContext) {
932        let executor = cx.executor();
933        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
934
935        let client = Peer::new(0);
936        let (connection_id, io_handler, mut incoming) =
937            client.add_test_connection(client_conn, executor.clone());
938        executor.spawn(io_handler).detach();
939        executor
940            .spawn(async move { incoming.next().await })
941            .detach();
942
943        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
944        let _request = server_conn.rx.next().await.unwrap().unwrap();
945
946        drop(server_conn);
947        assert_eq!(
948            response.await.unwrap_err().to_string(),
949            "connection was closed"
950        );
951    }
952}