peer.rs

  1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
  2
  3use super::{
  4    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  5    Connection,
  6};
  7use anyhow::{anyhow, Context, Result};
  8use collections::HashMap;
  9use futures::{
 10    channel::{mpsc, oneshot},
 11    stream::BoxStream,
 12    FutureExt, SinkExt, StreamExt, TryFutureExt,
 13};
 14use parking_lot::{Mutex, RwLock};
 15use serde::{ser::SerializeStruct, Serialize};
 16use std::{fmt, sync::atomic::Ordering::SeqCst};
 17use std::{
 18    future::Future,
 19    marker::PhantomData,
 20    sync::{
 21        atomic::{self, AtomicU32},
 22        Arc,
 23    },
 24    time::Duration,
 25};
 26use tracing::instrument;
 27
/// Identifies one connection managed by a [`Peer`].
///
/// When a connection is added, `owner_id` is filled from the peer's epoch and
/// `id` is allocated from a per-peer counter (see `Peer::add_connection`).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// Epoch of the peer that created this connection.
    pub owner_id: u32,
    /// Counter value assigned to the connection by that peer.
    pub id: u32,
}
 33
 34impl Into<PeerId> for ConnectionId {
 35    fn into(self) -> PeerId {
 36        PeerId {
 37            owner_id: self.owner_id,
 38            id: self.id,
 39        }
 40    }
 41}
 42
 43impl From<PeerId> for ConnectionId {
 44    fn from(peer_id: PeerId) -> Self {
 45        Self {
 46            owner_id: peer_id.owner_id,
 47            id: peer_id.id,
 48        }
 49    }
 50}
 51
 52impl fmt::Display for ConnectionId {
 53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 54        write!(f, "{}/{}", self.owner_id, self.id)
 55    }
 56}
 57
/// Identifies a received request so that a reply can be routed back to it
/// (see `Peer::respond` and `Peer::respond_with_error`).
///
/// `T` is the request type. It is tracked only at the type level through
/// `PhantomData`; no `T` value is stored.
pub struct Receipt<T> {
    /// Connection the request arrived on; replies are sent back on it.
    pub sender_id: ConnectionId,
    /// Message id of the request; replies pass it as the id they respond to.
    pub message_id: u32,
    payload_type: PhantomData<T>,
}
 63
 64impl<T> Clone for Receipt<T> {
 65    fn clone(&self) -> Self {
 66        Self {
 67            sender_id: self.sender_id,
 68            message_id: self.message_id,
 69            payload_type: PhantomData,
 70        }
 71    }
 72}
 73
 74impl<T> Copy for Receipt<T> {}
 75
/// A decoded message of type `T` together with its routing metadata.
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
    /// Connection the message arrived on.
    pub sender_id: ConnectionId,
    /// Peer that originally sent the message, when the sender set it while
    /// forwarding (see `Peer::forward_send` / `Peer::forward_request`).
    pub original_sender_id: Option<PeerId>,
    /// Id the sender assigned to this message on its connection.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
 83
 84impl<T> TypedEnvelope<T> {
 85    pub fn original_sender_id(&self) -> Result<PeerId> {
 86        self.original_sender_id
 87            .ok_or_else(|| anyhow!("missing original_sender_id"))
 88    }
 89}
 90
 91impl<T: RequestMessage> TypedEnvelope<T> {
 92    pub fn receipt(&self) -> Receipt<T> {
 93        Receipt {
 94            sender_id: self.sender_id,
 95            message_id: self.message_id,
 96            payload_type: PhantomData,
 97        }
 98    }
 99}
100
/// Multiplexes RPC traffic over any number of connections.
///
/// Connections are registered with [`Peer::add_connection`]. Each stays in
/// `connections` until its I/O future ends, or until it is removed with
/// [`Peer::disconnect`] / [`Peer::teardown`].
pub struct Peer {
    /// Value stamped into `owner_id` of newly created connection ids.
    epoch: AtomicU32,
    /// Live connections, keyed by id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter used to allocate the `id` of the next connection.
    next_connection_id: AtomicU32,
}
106
/// Per-connection state shared between the connection's I/O future and the
/// code that sends on it. Clones share the same channel and counters, because
/// every field is a channel sender or an `Arc`.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages that the connection's I/O future writes out.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Id to assign to the next outgoing message on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Pending requests keyed by request message id. Each sender delivers the
    /// response envelope plus a `oneshot::Sender<()>` that the requester drops
    /// once it has resumed. Set to `None` when the I/O future ends, which fails
    /// all pending and future requests.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
117
/// How long a connection may go without writing anything before a keepalive
/// `Ping` is sent.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for a single socket write, or for handing an incoming
/// message to the incoming stream, before the connection fails.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum time allowed between received messages before the connection fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
121
122impl Peer {
123    pub fn new(epoch: u32) -> Arc<Self> {
124        Arc::new(Self {
125            epoch: AtomicU32::new(epoch),
126            connections: Default::default(),
127            next_connection_id: Default::default(),
128        })
129    }
130
131    pub fn epoch(&self) -> u32 {
132        self.epoch.load(SeqCst)
133    }
134
135    #[instrument(skip_all)]
136    pub fn add_connection<F, Fut, Out>(
137        self: &Arc<Self>,
138        connection: Connection,
139        create_timer: F,
140    ) -> (
141        ConnectionId,
142        impl Future<Output = anyhow::Result<()>> + Send,
143        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
144    )
145    where
146        F: Send + Fn(Duration) -> Fut,
147        Fut: Send + Future<Output = Out>,
148        Out: Send,
149    {
150        // For outgoing messages, use an unbounded channel so that application code
151        // can always send messages without yielding. For incoming messages, use a
152        // bounded channel so that other peers will receive backpressure if they send
153        // messages faster than this peer can process them.
154        #[cfg(any(test, feature = "test-support"))]
155        const INCOMING_BUFFER_SIZE: usize = 1;
156        #[cfg(not(any(test, feature = "test-support")))]
157        const INCOMING_BUFFER_SIZE: usize = 64;
158        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
159        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
160
161        let connection_id = ConnectionId {
162            owner_id: self.epoch.load(SeqCst),
163            id: self.next_connection_id.fetch_add(1, SeqCst),
164        };
165        let connection_state = ConnectionState {
166            outgoing_tx,
167            next_message_id: Default::default(),
168            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
169        };
170        let mut writer = MessageStream::new(connection.tx);
171        let mut reader = MessageStream::new(connection.rx);
172
173        let this = self.clone();
174        let response_channels = connection_state.response_channels.clone();
175        let handle_io = async move {
176            tracing::trace!(%connection_id, "handle io future: start");
177
178            let _end_connection = util::defer(|| {
179                response_channels.lock().take();
180                this.connections.write().remove(&connection_id);
181                tracing::trace!(%connection_id, "handle io future: end");
182            });
183
184            // Send messages on this frequency so the connection isn't closed.
185            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
186            futures::pin_mut!(keepalive_timer);
187
188            // Disconnect if we don't receive messages at least this frequently.
189            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
190            futures::pin_mut!(receive_timeout);
191
192            loop {
193                tracing::trace!(%connection_id, "outer loop iteration start");
194                let read_message = reader.read().fuse();
195                futures::pin_mut!(read_message);
196
197                loop {
198                    tracing::trace!(%connection_id, "inner loop iteration start");
199                    futures::select_biased! {
200                        outgoing = outgoing_rx.next().fuse() => match outgoing {
201                            Some(outgoing) => {
202                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
203                                futures::select_biased! {
204                                    result = writer.write(outgoing).fuse() => {
205                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
206                                        result.context("failed to write RPC message")?;
207                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
208                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
209                                    }
210                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
211                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
212                                        Err(anyhow!("timed out writing message"))?;
213                                    }
214                                }
215                            }
216                            None => {
217                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
218                                return Ok(())
219                            },
220                        },
221                        _ = keepalive_timer => {
222                            tracing::trace!(%connection_id, "keepalive interval: pinging");
223                            futures::select_biased! {
224                                result = writer.write(proto::Message::Ping).fuse() => {
225                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
226                                    result.context("failed to send keepalive")?;
227                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
228                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
229                                }
230                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
231                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
232                                    Err(anyhow!("timed out sending keepalive"))?;
233                                }
234                            }
235                        }
236                        incoming = read_message => {
237                            let incoming = incoming.context("error reading rpc message from socket")?;
238                            tracing::trace!(%connection_id, "incoming rpc message: received");
239                            tracing::trace!(%connection_id, "receive timeout: resetting");
240                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
241                            if let proto::Message::Envelope(incoming) = incoming {
242                                tracing::trace!(%connection_id, "incoming rpc message: processing");
243                                futures::select_biased! {
244                                    result = incoming_tx.send(incoming).fuse() => match result {
245                                        Ok(_) => {
246                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
247                                        }
248                                        Err(_) => {
249                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
250                                            return Ok(())
251                                        }
252                                    },
253                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
254                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
255                                        Err(anyhow!("timed out processing incoming message"))?
256                                    }
257                                }
258                            }
259                            break;
260                        },
261                        _ = receive_timeout => {
262                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
263                            Err(anyhow!("delay between messages too long"))?
264                        }
265                    }
266                }
267            }
268        };
269
270        let response_channels = connection_state.response_channels.clone();
271        self.connections
272            .write()
273            .insert(connection_id, connection_state);
274
275        let incoming_rx = incoming_rx.filter_map(move |incoming| {
276            let response_channels = response_channels.clone();
277            async move {
278                let message_id = incoming.id;
279                tracing::trace!(?incoming, "incoming message future: start");
280                let _end = util::defer(move || {
281                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
282                });
283
284                if let Some(responding_to) = incoming.responding_to {
285                    tracing::trace!(
286                        %connection_id,
287                        message_id,
288                        responding_to,
289                        "incoming response: received"
290                    );
291                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
292                    if let Some(tx) = channel {
293                        let requester_resumed = oneshot::channel();
294                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
295                            tracing::trace!(
296                                %connection_id,
297                                message_id,
298                                responding_to = responding_to,
299                                ?error,
300                                "incoming response: request future dropped",
301                            );
302                        }
303
304                        tracing::trace!(
305                            %connection_id,
306                            message_id,
307                            responding_to,
308                            "incoming response: waiting to resume requester"
309                        );
310                        let _ = requester_resumed.1.await;
311                        tracing::trace!(
312                            %connection_id,
313                            message_id,
314                            responding_to,
315                            "incoming response: requester resumed"
316                        );
317                    } else {
318                        let message_type = proto::build_typed_envelope(connection_id, incoming)
319                            .map(|p| p.payload_type_name());
320                        tracing::warn!(
321                            %connection_id,
322                            message_id,
323                            responding_to,
324                            message_type,
325                            "incoming response: unknown request"
326                        );
327                    }
328
329                    None
330                } else {
331                    tracing::trace!(%connection_id, message_id, "incoming message: received");
332                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
333                        tracing::error!(
334                            %connection_id,
335                            message_id,
336                            "unable to construct a typed envelope"
337                        );
338                        None
339                    })
340                }
341            }
342        });
343        (connection_id, handle_io, incoming_rx.boxed())
344    }
345
346    #[cfg(any(test, feature = "test-support"))]
347    pub fn add_test_connection(
348        self: &Arc<Self>,
349        connection: Connection,
350        executor: gpui::BackgroundExecutor,
351    ) -> (
352        ConnectionId,
353        impl Future<Output = anyhow::Result<()>> + Send,
354        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
355    ) {
356        let executor = executor.clone();
357        self.add_connection(connection, move |duration| executor.timer(duration))
358    }
359
360    pub fn disconnect(&self, connection_id: ConnectionId) {
361        self.connections.write().remove(&connection_id);
362    }
363
364    #[cfg(any(test, feature = "test-support"))]
365    pub fn reset(&self, epoch: u32) {
366        self.next_connection_id.store(0, SeqCst);
367        self.epoch.store(epoch, SeqCst);
368    }
369
370    pub fn teardown(&self) {
371        self.connections.write().clear();
372    }
373
374    pub fn request<T: RequestMessage>(
375        &self,
376        receiver_id: ConnectionId,
377        request: T,
378    ) -> impl Future<Output = Result<T::Response>> {
379        self.request_internal(None, receiver_id, request)
380            .map_ok(|envelope| envelope.payload)
381    }
382
383    pub fn request_envelope<T: RequestMessage>(
384        &self,
385        receiver_id: ConnectionId,
386        request: T,
387    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
388        self.request_internal(None, receiver_id, request)
389    }
390
391    pub fn forward_request<T: RequestMessage>(
392        &self,
393        sender_id: ConnectionId,
394        receiver_id: ConnectionId,
395        request: T,
396    ) -> impl Future<Output = Result<T::Response>> {
397        self.request_internal(Some(sender_id), receiver_id, request)
398            .map_ok(|envelope| envelope.payload)
399    }
400
401    pub fn request_internal<T: RequestMessage>(
402        &self,
403        original_sender_id: Option<ConnectionId>,
404        receiver_id: ConnectionId,
405        request: T,
406    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
407        let (tx, rx) = oneshot::channel();
408        let send = self.connection_state(receiver_id).and_then(|connection| {
409            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
410            connection
411                .response_channels
412                .lock()
413                .as_mut()
414                .ok_or_else(|| anyhow!("connection was closed"))?
415                .insert(message_id, tx);
416            connection
417                .outgoing_tx
418                .unbounded_send(proto::Message::Envelope(request.into_envelope(
419                    message_id,
420                    None,
421                    original_sender_id.map(Into::into),
422                )))
423                .map_err(|_| anyhow!("connection was closed"))?;
424            Ok(())
425        });
426        async move {
427            send?;
428            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
429
430            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
431                Err(RpcError::from_proto(&error, T::NAME))
432            } else {
433                Ok(TypedEnvelope {
434                    message_id: response.id,
435                    sender_id: receiver_id,
436                    original_sender_id: response.original_sender_id,
437                    payload: T::Response::from_envelope(response)
438                        .ok_or_else(|| anyhow!("received response of the wrong type"))?,
439                })
440            }
441        }
442    }
443
444    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
445        let connection = self.connection_state(receiver_id)?;
446        let message_id = connection
447            .next_message_id
448            .fetch_add(1, atomic::Ordering::SeqCst);
449        connection
450            .outgoing_tx
451            .unbounded_send(proto::Message::Envelope(
452                message.into_envelope(message_id, None, None),
453            ))?;
454        Ok(())
455    }
456
457    pub fn forward_send<T: EnvelopedMessage>(
458        &self,
459        sender_id: ConnectionId,
460        receiver_id: ConnectionId,
461        message: T,
462    ) -> Result<()> {
463        let connection = self.connection_state(receiver_id)?;
464        let message_id = connection
465            .next_message_id
466            .fetch_add(1, atomic::Ordering::SeqCst);
467        connection
468            .outgoing_tx
469            .unbounded_send(proto::Message::Envelope(message.into_envelope(
470                message_id,
471                None,
472                Some(sender_id.into()),
473            )))?;
474        Ok(())
475    }
476
477    pub fn respond<T: RequestMessage>(
478        &self,
479        receipt: Receipt<T>,
480        response: T::Response,
481    ) -> Result<()> {
482        let connection = self.connection_state(receipt.sender_id)?;
483        let message_id = connection
484            .next_message_id
485            .fetch_add(1, atomic::Ordering::SeqCst);
486        connection
487            .outgoing_tx
488            .unbounded_send(proto::Message::Envelope(response.into_envelope(
489                message_id,
490                Some(receipt.message_id),
491                None,
492            )))?;
493        Ok(())
494    }
495
496    pub fn respond_with_error<T: RequestMessage>(
497        &self,
498        receipt: Receipt<T>,
499        response: proto::Error,
500    ) -> Result<()> {
501        let connection = self.connection_state(receipt.sender_id)?;
502        let message_id = connection
503            .next_message_id
504            .fetch_add(1, atomic::Ordering::SeqCst);
505        connection
506            .outgoing_tx
507            .unbounded_send(proto::Message::Envelope(response.into_envelope(
508                message_id,
509                Some(receipt.message_id),
510                None,
511            )))?;
512        Ok(())
513    }
514
515    pub fn respond_with_unhandled_message(
516        &self,
517        envelope: Box<dyn AnyTypedEnvelope>,
518    ) -> Result<()> {
519        let connection = self.connection_state(envelope.sender_id())?;
520        let response = ErrorCode::Internal
521            .message(format!(
522                "message {} was not handled",
523                envelope.payload_type_name()
524            ))
525            .to_proto();
526        let message_id = connection
527            .next_message_id
528            .fetch_add(1, atomic::Ordering::SeqCst);
529        connection
530            .outgoing_tx
531            .unbounded_send(proto::Message::Envelope(response.into_envelope(
532                message_id,
533                Some(envelope.message_id()),
534                None,
535            )))?;
536        Ok(())
537    }
538
539    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
540        let connections = self.connections.read();
541        let connection = connections
542            .get(&connection_id)
543            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
544        Ok(connection.clone())
545    }
546}
547
548impl Serialize for Peer {
549    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
550    where
551        S: serde::Serializer,
552    {
553        let mut state = serializer.serialize_struct("Peer", 2)?;
554        state.serialize_field("connections", &*self.connections.read())?;
555        state.end()
556    }
557}
558
559#[cfg(test)]
560mod tests {
561    use super::*;
562    use crate::TypedEnvelope;
563    use async_tungstenite::tungstenite::Message as WebSocketMessage;
564    use gpui::TestAppContext;
565
566    fn init_logger() {
567        if std::env::var("RUST_LOG").is_ok() {
568            env_logger::init();
569        }
570    }
571
572    #[gpui::test(iterations = 50)]
573    async fn test_request_response(cx: &mut TestAppContext) {
574        init_logger();
575
576        let executor = cx.executor();
577
578        // create 2 clients connected to 1 server
579        let server = Peer::new(0);
580        let client1 = Peer::new(0);
581        let client2 = Peer::new(0);
582
583        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
584            Connection::in_memory(cx.executor());
585        let (client1_conn_id, io_task1, client1_incoming) =
586            client1.add_test_connection(client1_to_server_conn, cx.executor());
587        let (_, io_task2, server_incoming1) =
588            server.add_test_connection(server_to_client_1_conn, cx.executor());
589
590        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
591            Connection::in_memory(cx.executor());
592        let (client2_conn_id, io_task3, client2_incoming) =
593            client2.add_test_connection(client2_to_server_conn, cx.executor());
594        let (_, io_task4, server_incoming2) =
595            server.add_test_connection(server_to_client_2_conn, cx.executor());
596
597        executor.spawn(io_task1).detach();
598        executor.spawn(io_task2).detach();
599        executor.spawn(io_task3).detach();
600        executor.spawn(io_task4).detach();
601        executor
602            .spawn(handle_messages(server_incoming1, server.clone()))
603            .detach();
604        executor
605            .spawn(handle_messages(client1_incoming, client1.clone()))
606            .detach();
607        executor
608            .spawn(handle_messages(server_incoming2, server.clone()))
609            .detach();
610        executor
611            .spawn(handle_messages(client2_incoming, client2.clone()))
612            .detach();
613
614        assert_eq!(
615            client1
616                .request(client1_conn_id, proto::Ping {},)
617                .await
618                .unwrap(),
619            proto::Ack {}
620        );
621
622        assert_eq!(
623            client2
624                .request(client2_conn_id, proto::Ping {},)
625                .await
626                .unwrap(),
627            proto::Ack {}
628        );
629
630        assert_eq!(
631            client1
632                .request(client1_conn_id, proto::Test { id: 1 },)
633                .await
634                .unwrap(),
635            proto::Test { id: 1 }
636        );
637
638        assert_eq!(
639            client2
640                .request(client2_conn_id, proto::Test { id: 2 })
641                .await
642                .unwrap(),
643            proto::Test { id: 2 }
644        );
645
646        client1.disconnect(client1_conn_id);
647        client2.disconnect(client1_conn_id);
648
649        async fn handle_messages(
650            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
651            peer: Arc<Peer>,
652        ) -> Result<()> {
653            while let Some(envelope) = messages.next().await {
654                let envelope = envelope.into_any();
655                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
656                    let receipt = envelope.receipt();
657                    peer.respond(receipt, proto::Ack {})?
658                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
659                {
660                    peer.respond(envelope.receipt(), envelope.payload.clone())?
661                } else {
662                    panic!("unknown message type");
663                }
664            }
665
666            Ok(())
667        }
668    }
669
    /// The server sends two unsolicited messages and then the response to the
    /// client's `Ping`. The client must see both messages on its incoming
    /// stream before its request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: once the Ping arrives, send "message 1", "message 2" and
        // then the Ack response, in that order.
        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes each event.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: consume the two unsolicited messages from the incoming
        // stream and record their text.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
767
768    #[gpui::test(iterations = 50)]
769    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
770        let executor = cx.executor();
771        let server = Peer::new(0);
772        let client = Peer::new(0);
773
774        let (client_to_server_conn, server_to_client_conn, _kill) =
775            Connection::in_memory(cx.executor());
776        let (client_to_server_conn_id, io_task1, mut client_incoming) =
777            client.add_test_connection(client_to_server_conn, cx.executor());
778        let (server_to_client_conn_id, io_task2, mut server_incoming) =
779            server.add_test_connection(server_to_client_conn, cx.executor());
780
781        executor.spawn(io_task1).detach();
782        executor.spawn(io_task2).detach();
783
784        executor
785            .spawn(async move {
786                let request1 = server_incoming
787                    .next()
788                    .await
789                    .unwrap()
790                    .into_any()
791                    .downcast::<TypedEnvelope<proto::Ping>>()
792                    .unwrap();
793                let request2 = server_incoming
794                    .next()
795                    .await
796                    .unwrap()
797                    .into_any()
798                    .downcast::<TypedEnvelope<proto::Ping>>()
799                    .unwrap();
800
801                server
802                    .send(
803                        server_to_client_conn_id,
804                        ErrorCode::Internal
805                            .message("message 1".to_string())
806                            .to_proto(),
807                    )
808                    .unwrap();
809                server
810                    .send(
811                        server_to_client_conn_id,
812                        ErrorCode::Internal
813                            .message("message 2".to_string())
814                            .to_proto(),
815                    )
816                    .unwrap();
817                server.respond(request1.receipt(), proto::Ack {}).unwrap();
818                server.respond(request2.receipt(), proto::Ack {}).unwrap();
819
820                // Prevent the connection from being dropped
821                server_incoming.next().await;
822            })
823            .detach();
824
825        let events = Arc::new(Mutex::new(Vec::new()));
826
827        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
828        let request1_task = executor.spawn(request1);
829        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
830        let request2_task = executor.spawn({
831            let events = events.clone();
832            async move {
833                request2.await.unwrap();
834                events.lock().push("response 2".to_string());
835            }
836        });
837
838        executor
839            .spawn({
840                let events = events.clone();
841                async move {
842                    let incoming1 = client_incoming
843                        .next()
844                        .await
845                        .unwrap()
846                        .into_any()
847                        .downcast::<TypedEnvelope<proto::Error>>()
848                        .unwrap();
849                    events.lock().push(incoming1.payload.message);
850                    let incoming2 = client_incoming
851                        .next()
852                        .await
853                        .unwrap()
854                        .into_any()
855                        .downcast::<TypedEnvelope<proto::Error>>()
856                        .unwrap();
857                    events.lock().push(incoming2.payload.message);
858
859                    // Prevent the connection from being dropped
860                    client_incoming.next().await;
861                }
862            })
863            .detach();
864
865        // Allow the request to make some progress before dropping it.
866        cx.executor().simulate_random_delay().await;
867        drop(request1_task);
868
869        request2_task.await;
870        assert_eq!(
871            &*events.lock(),
872            &[
873                "message 1".to_string(),
874                "message 2".to_string(),
875                "response 2".to_string()
876            ]
877        );
878    }
879
880    #[gpui::test(iterations = 50)]
881    async fn test_disconnect(cx: &mut TestAppContext) {
882        let executor = cx.executor();
883
884        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
885
886        let client = Peer::new(0);
887        let (connection_id, io_handler, mut incoming) =
888            client.add_test_connection(client_conn, executor.clone());
889
890        let (io_ended_tx, io_ended_rx) = oneshot::channel();
891        executor
892            .spawn(async move {
893                io_handler.await.ok();
894                io_ended_tx.send(()).unwrap();
895            })
896            .detach();
897
898        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
899        executor
900            .spawn(async move {
901                incoming.next().await;
902                messages_ended_tx.send(()).unwrap();
903            })
904            .detach();
905
906        client.disconnect(connection_id);
907
908        let _ = io_ended_rx.await;
909        let _ = messages_ended_rx.await;
910        assert!(server_conn
911            .send(WebSocketMessage::Binary(vec![]))
912            .await
913            .is_err());
914    }
915
916    #[gpui::test(iterations = 50)]
917    async fn test_io_error(cx: &mut TestAppContext) {
918        let executor = cx.executor();
919        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
920
921        let client = Peer::new(0);
922        let (connection_id, io_handler, mut incoming) =
923            client.add_test_connection(client_conn, executor.clone());
924        executor.spawn(io_handler).detach();
925        executor
926            .spawn(async move { incoming.next().await })
927            .detach();
928
929        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
930        let _request = server_conn.rx.next().await.unwrap().unwrap();
931
932        drop(server_conn);
933        assert_eq!(
934            response.await.unwrap_err().to_string(),
935            "connection was closed"
936        );
937    }
938}