peer.rs

  1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
  2
  3use super::{
  4    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  5    Connection,
  6};
  7use anyhow::{anyhow, Context, Result};
  8use collections::HashMap;
  9use futures::{
 10    channel::{mpsc, oneshot},
 11    stream::BoxStream,
 12    FutureExt, SinkExt, StreamExt, TryFutureExt,
 13};
 14use parking_lot::{Mutex, RwLock};
 15use serde::{ser::SerializeStruct, Serialize};
 16use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
 17use std::{
 18    future::Future,
 19    marker::PhantomData,
 20    sync::{
 21        atomic::{self, AtomicU32},
 22        Arc,
 23    },
 24    time::Duration,
 25};
 26use tracing::instrument;
 27
 28#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
 29pub struct ConnectionId {
 30    pub owner_id: u32,
 31    pub id: u32,
 32}
 33
 34impl Into<PeerId> for ConnectionId {
 35    fn into(self) -> PeerId {
 36        PeerId {
 37            owner_id: self.owner_id,
 38            id: self.id,
 39        }
 40    }
 41}
 42
 43impl From<PeerId> for ConnectionId {
 44    fn from(peer_id: PeerId) -> Self {
 45        Self {
 46            owner_id: peer_id.owner_id,
 47            id: peer_id.id,
 48        }
 49    }
 50}
 51
 52impl fmt::Display for ConnectionId {
 53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 54        write!(f, "{}/{}", self.owner_id, self.id)
 55    }
 56}
 57
 58pub struct Receipt<T> {
 59    pub sender_id: ConnectionId,
 60    pub message_id: u32,
 61    payload_type: PhantomData<T>,
 62}
 63
 64impl<T> Clone for Receipt<T> {
 65    fn clone(&self) -> Self {
 66        *self
 67    }
 68}
 69
 70impl<T> Copy for Receipt<T> {}
 71
 72#[derive(Clone, Debug)]
 73pub struct TypedEnvelope<T> {
 74    pub sender_id: ConnectionId,
 75    pub original_sender_id: Option<PeerId>,
 76    pub message_id: u32,
 77    pub payload: T,
 78    pub received_at: Instant,
 79}
 80
 81impl<T> TypedEnvelope<T> {
 82    pub fn original_sender_id(&self) -> Result<PeerId> {
 83        self.original_sender_id
 84            .ok_or_else(|| anyhow!("missing original_sender_id"))
 85    }
 86}
 87
 88impl<T: RequestMessage> TypedEnvelope<T> {
 89    pub fn receipt(&self) -> Receipt<T> {
 90        Receipt {
 91            sender_id: self.sender_id,
 92            message_id: self.message_id,
 93            payload_type: PhantomData,
 94        }
 95    }
 96}
 97
 98pub struct Peer {
 99    epoch: AtomicU32,
100    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
101    next_connection_id: AtomicU32,
102}
103
104#[derive(Clone, Serialize)]
105pub struct ConnectionState {
106    #[serde(skip)]
107    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
108    next_message_id: Arc<AtomicU32>,
109    #[allow(clippy::type_complexity)]
110    #[serde(skip)]
111    response_channels: Arc<
112        Mutex<
113            Option<
114                HashMap<
115                    u32,
116                    oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
117                >,
118            >,
119        >,
120    >,
121}
122
123const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
124const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
125pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
126
127impl Peer {
128    pub fn new(epoch: u32) -> Arc<Self> {
129        Arc::new(Self {
130            epoch: AtomicU32::new(epoch),
131            connections: Default::default(),
132            next_connection_id: Default::default(),
133        })
134    }
135
136    pub fn epoch(&self) -> u32 {
137        self.epoch.load(SeqCst)
138    }
139
140    #[instrument(skip_all)]
141    pub fn add_connection<F, Fut, Out>(
142        self: &Arc<Self>,
143        connection: Connection,
144        create_timer: F,
145    ) -> (
146        ConnectionId,
147        impl Future<Output = anyhow::Result<()>> + Send,
148        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
149    )
150    where
151        F: Send + Fn(Duration) -> Fut,
152        Fut: Send + Future<Output = Out>,
153        Out: Send,
154    {
155        // For outgoing messages, use an unbounded channel so that application code
156        // can always send messages without yielding. For incoming messages, use a
157        // bounded channel so that other peers will receive backpressure if they send
158        // messages faster than this peer can process them.
159        #[cfg(any(test, feature = "test-support"))]
160        const INCOMING_BUFFER_SIZE: usize = 1;
161        #[cfg(not(any(test, feature = "test-support")))]
162        const INCOMING_BUFFER_SIZE: usize = 256;
163        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
164        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
165
166        let connection_id = ConnectionId {
167            owner_id: self.epoch.load(SeqCst),
168            id: self.next_connection_id.fetch_add(1, SeqCst),
169        };
170        let connection_state = ConnectionState {
171            outgoing_tx,
172            next_message_id: Default::default(),
173            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
174        };
175        let mut writer = MessageStream::new(connection.tx);
176        let mut reader = MessageStream::new(connection.rx);
177
178        let this = self.clone();
179        let response_channels = connection_state.response_channels.clone();
180        let handle_io = async move {
181            tracing::trace!(%connection_id, "handle io future: start");
182
183            let _end_connection = util::defer(|| {
184                response_channels.lock().take();
185                this.connections.write().remove(&connection_id);
186                tracing::trace!(%connection_id, "handle io future: end");
187            });
188
189            // Send messages on this frequency so the connection isn't closed.
190            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
191            futures::pin_mut!(keepalive_timer);
192
193            // Disconnect if we don't receive messages at least this frequently.
194            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
195            futures::pin_mut!(receive_timeout);
196
197            loop {
198                tracing::trace!(%connection_id, "outer loop iteration start");
199                let read_message = reader.read().fuse();
200                futures::pin_mut!(read_message);
201
202                loop {
203                    tracing::trace!(%connection_id, "inner loop iteration start");
204                    futures::select_biased! {
205                        outgoing = outgoing_rx.next().fuse() => match outgoing {
206                            Some(outgoing) => {
207                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
208                                futures::select_biased! {
209                                    result = writer.write(outgoing).fuse() => {
210                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
211                                        result.context("failed to write RPC message")?;
212                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
213                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
214                                    }
215                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
216                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
217                                        Err(anyhow!("timed out writing message"))?;
218                                    }
219                                }
220                            }
221                            None => {
222                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
223                                return Ok(())
224                            },
225                        },
226                        _ = keepalive_timer => {
227                            tracing::trace!(%connection_id, "keepalive interval: pinging");
228                            futures::select_biased! {
229                                result = writer.write(proto::Message::Ping).fuse() => {
230                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
231                                    result.context("failed to send keepalive")?;
232                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
233                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
234                                }
235                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
236                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
237                                    Err(anyhow!("timed out sending keepalive"))?;
238                                }
239                            }
240                        }
241                        incoming = read_message => {
242                            let incoming = incoming.context("error reading rpc message from socket")?;
243                            tracing::trace!(%connection_id, "incoming rpc message: received");
244                            tracing::trace!(%connection_id, "receive timeout: resetting");
245                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
246                            if let (proto::Message::Envelope(incoming), received_at) = incoming {
247                                tracing::trace!(%connection_id, "incoming rpc message: processing");
248                                futures::select_biased! {
249                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
250                                        Ok(_) => {
251                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
252                                        }
253                                        Err(_) => {
254                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
255                                            return Ok(())
256                                        }
257                                    },
258                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
259                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
260                                        Err(anyhow!("timed out processing incoming message"))?
261                                    }
262                                }
263                            }
264                            break;
265                        },
266                        _ = receive_timeout => {
267                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
268                            Err(anyhow!("delay between messages too long"))?
269                        }
270                    }
271                }
272            }
273        };
274
275        let response_channels = connection_state.response_channels.clone();
276        self.connections
277            .write()
278            .insert(connection_id, connection_state);
279
280        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
281            let response_channels = response_channels.clone();
282            async move {
283                let message_id = incoming.id;
284                tracing::trace!(?incoming, "incoming message future: start");
285                let _end = util::defer(move || {
286                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
287                });
288
289                if let Some(responding_to) = incoming.responding_to {
290                    tracing::trace!(
291                        %connection_id,
292                        message_id,
293                        responding_to,
294                        "incoming response: received"
295                    );
296                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
297                    if let Some(tx) = channel {
298                        let requester_resumed = oneshot::channel();
299                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
300                            tracing::trace!(
301                                %connection_id,
302                                message_id,
303                                responding_to = responding_to,
304                                ?error,
305                                "incoming response: request future dropped",
306                            );
307                        }
308
309                        tracing::trace!(
310                            %connection_id,
311                            message_id,
312                            responding_to,
313                            "incoming response: waiting to resume requester"
314                        );
315                        let _ = requester_resumed.1.await;
316                        tracing::trace!(
317                            %connection_id,
318                            message_id,
319                            responding_to,
320                            "incoming response: requester resumed"
321                        );
322                    } else {
323                        let message_type =
324                            proto::build_typed_envelope(connection_id, received_at, incoming)
325                                .map(|p| p.payload_type_name());
326                        tracing::warn!(
327                            %connection_id,
328                            message_id,
329                            responding_to,
330                            message_type,
331                            "incoming response: unknown request"
332                        );
333                    }
334
335                    None
336                } else {
337                    tracing::trace!(%connection_id, message_id, "incoming message: received");
338                    proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
339                        || {
340                            tracing::error!(
341                                %connection_id,
342                                message_id,
343                                "unable to construct a typed envelope"
344                            );
345                            None
346                        },
347                    )
348                }
349            }
350        });
351        (connection_id, handle_io, incoming_rx.boxed())
352    }
353
354    #[cfg(any(test, feature = "test-support"))]
355    pub fn add_test_connection(
356        self: &Arc<Self>,
357        connection: Connection,
358        executor: gpui::BackgroundExecutor,
359    ) -> (
360        ConnectionId,
361        impl Future<Output = anyhow::Result<()>> + Send,
362        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
363    ) {
364        let executor = executor.clone();
365        self.add_connection(connection, move |duration| executor.timer(duration))
366    }
367
368    pub fn disconnect(&self, connection_id: ConnectionId) {
369        self.connections.write().remove(&connection_id);
370    }
371
372    #[cfg(any(test, feature = "test-support"))]
373    pub fn reset(&self, epoch: u32) {
374        self.next_connection_id.store(0, SeqCst);
375        self.epoch.store(epoch, SeqCst);
376    }
377
378    pub fn teardown(&self) {
379        self.connections.write().clear();
380    }
381
382    pub fn request<T: RequestMessage>(
383        &self,
384        receiver_id: ConnectionId,
385        request: T,
386    ) -> impl Future<Output = Result<T::Response>> {
387        self.request_internal(None, receiver_id, request)
388            .map_ok(|envelope| envelope.payload)
389    }
390
391    pub fn request_envelope<T: RequestMessage>(
392        &self,
393        receiver_id: ConnectionId,
394        request: T,
395    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
396        self.request_internal(None, receiver_id, request)
397    }
398
399    pub fn forward_request<T: RequestMessage>(
400        &self,
401        sender_id: ConnectionId,
402        receiver_id: ConnectionId,
403        request: T,
404    ) -> impl Future<Output = Result<T::Response>> {
405        self.request_internal(Some(sender_id), receiver_id, request)
406            .map_ok(|envelope| envelope.payload)
407    }
408
409    pub fn request_internal<T: RequestMessage>(
410        &self,
411        original_sender_id: Option<ConnectionId>,
412        receiver_id: ConnectionId,
413        request: T,
414    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
415        let (tx, rx) = oneshot::channel();
416        let send = self.connection_state(receiver_id).and_then(|connection| {
417            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
418            connection
419                .response_channels
420                .lock()
421                .as_mut()
422                .ok_or_else(|| anyhow!("connection was closed"))?
423                .insert(message_id, tx);
424            connection
425                .outgoing_tx
426                .unbounded_send(proto::Message::Envelope(request.into_envelope(
427                    message_id,
428                    None,
429                    original_sender_id.map(Into::into),
430                )))
431                .map_err(|_| anyhow!("connection was closed"))?;
432            Ok(())
433        });
434        async move {
435            send?;
436            let (response, received_at, _barrier) =
437                rx.await.map_err(|_| anyhow!("connection was closed"))?;
438
439            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
440                Err(RpcError::from_proto(&error, T::NAME))
441            } else {
442                Ok(TypedEnvelope {
443                    message_id: response.id,
444                    sender_id: receiver_id,
445                    original_sender_id: response.original_sender_id,
446                    payload: T::Response::from_envelope(response)
447                        .ok_or_else(|| anyhow!("received response of the wrong type"))?,
448                    received_at,
449                })
450            }
451        }
452    }
453
454    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
455        let connection = self.connection_state(receiver_id)?;
456        let message_id = connection
457            .next_message_id
458            .fetch_add(1, atomic::Ordering::SeqCst);
459        connection
460            .outgoing_tx
461            .unbounded_send(proto::Message::Envelope(
462                message.into_envelope(message_id, None, None),
463            ))?;
464        Ok(())
465    }
466
467    pub fn forward_send<T: EnvelopedMessage>(
468        &self,
469        sender_id: ConnectionId,
470        receiver_id: ConnectionId,
471        message: T,
472    ) -> Result<()> {
473        let connection = self.connection_state(receiver_id)?;
474        let message_id = connection
475            .next_message_id
476            .fetch_add(1, atomic::Ordering::SeqCst);
477        connection
478            .outgoing_tx
479            .unbounded_send(proto::Message::Envelope(message.into_envelope(
480                message_id,
481                None,
482                Some(sender_id.into()),
483            )))?;
484        Ok(())
485    }
486
487    pub fn respond<T: RequestMessage>(
488        &self,
489        receipt: Receipt<T>,
490        response: T::Response,
491    ) -> Result<()> {
492        let connection = self.connection_state(receipt.sender_id)?;
493        let message_id = connection
494            .next_message_id
495            .fetch_add(1, atomic::Ordering::SeqCst);
496        connection
497            .outgoing_tx
498            .unbounded_send(proto::Message::Envelope(response.into_envelope(
499                message_id,
500                Some(receipt.message_id),
501                None,
502            )))?;
503        Ok(())
504    }
505
506    pub fn respond_with_error<T: RequestMessage>(
507        &self,
508        receipt: Receipt<T>,
509        response: proto::Error,
510    ) -> Result<()> {
511        let connection = self.connection_state(receipt.sender_id)?;
512        let message_id = connection
513            .next_message_id
514            .fetch_add(1, atomic::Ordering::SeqCst);
515        connection
516            .outgoing_tx
517            .unbounded_send(proto::Message::Envelope(response.into_envelope(
518                message_id,
519                Some(receipt.message_id),
520                None,
521            )))?;
522        Ok(())
523    }
524
525    pub fn respond_with_unhandled_message(
526        &self,
527        envelope: Box<dyn AnyTypedEnvelope>,
528    ) -> Result<()> {
529        let connection = self.connection_state(envelope.sender_id())?;
530        let response = ErrorCode::Internal
531            .message(format!(
532                "message {} was not handled",
533                envelope.payload_type_name()
534            ))
535            .to_proto();
536        let message_id = connection
537            .next_message_id
538            .fetch_add(1, atomic::Ordering::SeqCst);
539        connection
540            .outgoing_tx
541            .unbounded_send(proto::Message::Envelope(response.into_envelope(
542                message_id,
543                Some(envelope.message_id()),
544                None,
545            )))?;
546        Ok(())
547    }
548
549    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
550        let connections = self.connections.read();
551        let connection = connections
552            .get(&connection_id)
553            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
554        Ok(connection.clone())
555    }
556}
557
558impl Serialize for Peer {
559    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
560    where
561        S: serde::Serializer,
562    {
563        let mut state = serializer.serialize_struct("Peer", 2)?;
564        state.serialize_field("connections", &*self.connections.read())?;
565        state.end()
566    }
567}
568
569#[cfg(test)]
570mod tests {
571    use super::*;
572    use crate::TypedEnvelope;
573    use async_tungstenite::tungstenite::Message as WebSocketMessage;
574    use gpui::TestAppContext;
575
576    fn init_logger() {
577        if std::env::var("RUST_LOG").is_ok() {
578            env_logger::init();
579        }
580    }
581
582    #[gpui::test(iterations = 50)]
583    async fn test_request_response(cx: &mut TestAppContext) {
584        init_logger();
585
586        let executor = cx.executor();
587
588        // create 2 clients connected to 1 server
589        let server = Peer::new(0);
590        let client1 = Peer::new(0);
591        let client2 = Peer::new(0);
592
593        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
594            Connection::in_memory(cx.executor());
595        let (client1_conn_id, io_task1, client1_incoming) =
596            client1.add_test_connection(client1_to_server_conn, cx.executor());
597        let (_, io_task2, server_incoming1) =
598            server.add_test_connection(server_to_client_1_conn, cx.executor());
599
600        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
601            Connection::in_memory(cx.executor());
602        let (client2_conn_id, io_task3, client2_incoming) =
603            client2.add_test_connection(client2_to_server_conn, cx.executor());
604        let (_, io_task4, server_incoming2) =
605            server.add_test_connection(server_to_client_2_conn, cx.executor());
606
607        executor.spawn(io_task1).detach();
608        executor.spawn(io_task2).detach();
609        executor.spawn(io_task3).detach();
610        executor.spawn(io_task4).detach();
611        executor
612            .spawn(handle_messages(server_incoming1, server.clone()))
613            .detach();
614        executor
615            .spawn(handle_messages(client1_incoming, client1.clone()))
616            .detach();
617        executor
618            .spawn(handle_messages(server_incoming2, server.clone()))
619            .detach();
620        executor
621            .spawn(handle_messages(client2_incoming, client2.clone()))
622            .detach();
623
624        assert_eq!(
625            client1
626                .request(client1_conn_id, proto::Ping {},)
627                .await
628                .unwrap(),
629            proto::Ack {}
630        );
631
632        assert_eq!(
633            client2
634                .request(client2_conn_id, proto::Ping {},)
635                .await
636                .unwrap(),
637            proto::Ack {}
638        );
639
640        assert_eq!(
641            client1
642                .request(client1_conn_id, proto::Test { id: 1 },)
643                .await
644                .unwrap(),
645            proto::Test { id: 1 }
646        );
647
648        assert_eq!(
649            client2
650                .request(client2_conn_id, proto::Test { id: 2 })
651                .await
652                .unwrap(),
653            proto::Test { id: 2 }
654        );
655
656        client1.disconnect(client1_conn_id);
657        client2.disconnect(client1_conn_id);
658
659        async fn handle_messages(
660            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
661            peer: Arc<Peer>,
662        ) -> Result<()> {
663            while let Some(envelope) = messages.next().await {
664                let envelope = envelope.into_any();
665                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
666                    let receipt = envelope.receipt();
667                    peer.respond(receipt, proto::Ack {})?
668                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
669                {
670                    peer.respond(envelope.receipt(), envelope.payload.clone())?
671                } else {
672                    panic!("unknown message type");
673                }
674            }
675
676            Ok(())
677        }
678    }
679
680    #[gpui::test(iterations = 50)]
681    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
682        let executor = cx.executor();
683        let server = Peer::new(0);
684        let client = Peer::new(0);
685
686        let (client_to_server_conn, server_to_client_conn, _kill) =
687            Connection::in_memory(executor.clone());
688        let (client_to_server_conn_id, io_task1, mut client_incoming) =
689            client.add_test_connection(client_to_server_conn, executor.clone());
690
691        let (server_to_client_conn_id, io_task2, mut server_incoming) =
692            server.add_test_connection(server_to_client_conn, executor.clone());
693
694        executor.spawn(io_task1).detach();
695        executor.spawn(io_task2).detach();
696
697        executor
698            .spawn(async move {
699                let future = server_incoming.next().await;
700                let request = future
701                    .unwrap()
702                    .into_any()
703                    .downcast::<TypedEnvelope<proto::Ping>>()
704                    .unwrap();
705
706                server
707                    .send(
708                        server_to_client_conn_id,
709                        ErrorCode::Internal
710                            .message("message 1".to_string())
711                            .to_proto(),
712                    )
713                    .unwrap();
714                server
715                    .send(
716                        server_to_client_conn_id,
717                        ErrorCode::Internal
718                            .message("message 2".to_string())
719                            .to_proto(),
720                    )
721                    .unwrap();
722                server.respond(request.receipt(), proto::Ack {}).unwrap();
723
724                // Prevent the connection from being dropped
725                server_incoming.next().await;
726            })
727            .detach();
728
729        let events = Arc::new(Mutex::new(Vec::new()));
730
731        let response = client.request(client_to_server_conn_id, proto::Ping {});
732        let response_task = executor.spawn({
733            let events = events.clone();
734            async move {
735                response.await.unwrap();
736                events.lock().push("response".to_string());
737            }
738        });
739
740        executor
741            .spawn({
742                let events = events.clone();
743                async move {
744                    let incoming1 = client_incoming
745                        .next()
746                        .await
747                        .unwrap()
748                        .into_any()
749                        .downcast::<TypedEnvelope<proto::Error>>()
750                        .unwrap();
751                    events.lock().push(incoming1.payload.message);
752                    let incoming2 = client_incoming
753                        .next()
754                        .await
755                        .unwrap()
756                        .into_any()
757                        .downcast::<TypedEnvelope<proto::Error>>()
758                        .unwrap();
759                    events.lock().push(incoming2.payload.message);
760
761                    // Prevent the connection from being dropped
762                    client_incoming.next().await;
763                }
764            })
765            .detach();
766
767        response_task.await;
768        assert_eq!(
769            &*events.lock(),
770            &[
771                "message 1".to_string(),
772                "message 2".to_string(),
773                "response".to_string()
774            ]
775        );
776    }
777
778    #[gpui::test(iterations = 50)]
779    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
780        let executor = cx.executor();
781        let server = Peer::new(0);
782        let client = Peer::new(0);
783
784        let (client_to_server_conn, server_to_client_conn, _kill) =
785            Connection::in_memory(cx.executor());
786        let (client_to_server_conn_id, io_task1, mut client_incoming) =
787            client.add_test_connection(client_to_server_conn, cx.executor());
788        let (server_to_client_conn_id, io_task2, mut server_incoming) =
789            server.add_test_connection(server_to_client_conn, cx.executor());
790
791        executor.spawn(io_task1).detach();
792        executor.spawn(io_task2).detach();
793
794        executor
795            .spawn(async move {
796                let request1 = server_incoming
797                    .next()
798                    .await
799                    .unwrap()
800                    .into_any()
801                    .downcast::<TypedEnvelope<proto::Ping>>()
802                    .unwrap();
803                let request2 = server_incoming
804                    .next()
805                    .await
806                    .unwrap()
807                    .into_any()
808                    .downcast::<TypedEnvelope<proto::Ping>>()
809                    .unwrap();
810
811                server
812                    .send(
813                        server_to_client_conn_id,
814                        ErrorCode::Internal
815                            .message("message 1".to_string())
816                            .to_proto(),
817                    )
818                    .unwrap();
819                server
820                    .send(
821                        server_to_client_conn_id,
822                        ErrorCode::Internal
823                            .message("message 2".to_string())
824                            .to_proto(),
825                    )
826                    .unwrap();
827                server.respond(request1.receipt(), proto::Ack {}).unwrap();
828                server.respond(request2.receipt(), proto::Ack {}).unwrap();
829
830                // Prevent the connection from being dropped
831                server_incoming.next().await;
832            })
833            .detach();
834
835        let events = Arc::new(Mutex::new(Vec::new()));
836
837        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
838        let request1_task = executor.spawn(request1);
839        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
840        let request2_task = executor.spawn({
841            let events = events.clone();
842            async move {
843                request2.await.unwrap();
844                events.lock().push("response 2".to_string());
845            }
846        });
847
848        executor
849            .spawn({
850                let events = events.clone();
851                async move {
852                    let incoming1 = client_incoming
853                        .next()
854                        .await
855                        .unwrap()
856                        .into_any()
857                        .downcast::<TypedEnvelope<proto::Error>>()
858                        .unwrap();
859                    events.lock().push(incoming1.payload.message);
860                    let incoming2 = client_incoming
861                        .next()
862                        .await
863                        .unwrap()
864                        .into_any()
865                        .downcast::<TypedEnvelope<proto::Error>>()
866                        .unwrap();
867                    events.lock().push(incoming2.payload.message);
868
869                    // Prevent the connection from being dropped
870                    client_incoming.next().await;
871                }
872            })
873            .detach();
874
875        // Allow the request to make some progress before dropping it.
876        cx.executor().simulate_random_delay().await;
877        drop(request1_task);
878
879        request2_task.await;
880        assert_eq!(
881            &*events.lock(),
882            &[
883                "message 1".to_string(),
884                "message 2".to_string(),
885                "response 2".to_string()
886            ]
887        );
888    }
889
890    #[gpui::test(iterations = 50)]
891    async fn test_disconnect(cx: &mut TestAppContext) {
892        let executor = cx.executor();
893
894        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
895
896        let client = Peer::new(0);
897        let (connection_id, io_handler, mut incoming) =
898            client.add_test_connection(client_conn, executor.clone());
899
900        let (io_ended_tx, io_ended_rx) = oneshot::channel();
901        executor
902            .spawn(async move {
903                io_handler.await.ok();
904                io_ended_tx.send(()).unwrap();
905            })
906            .detach();
907
908        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
909        executor
910            .spawn(async move {
911                incoming.next().await;
912                messages_ended_tx.send(()).unwrap();
913            })
914            .detach();
915
916        client.disconnect(connection_id);
917
918        let _ = io_ended_rx.await;
919        let _ = messages_ended_rx.await;
920        assert!(server_conn
921            .send(WebSocketMessage::Binary(vec![]))
922            .await
923            .is_err());
924    }
925
926    #[gpui::test(iterations = 50)]
927    async fn test_io_error(cx: &mut TestAppContext) {
928        let executor = cx.executor();
929        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());
930
931        let client = Peer::new(0);
932        let (connection_id, io_handler, mut incoming) =
933            client.add_test_connection(client_conn, executor.clone());
934        executor.spawn(io_handler).detach();
935        executor
936            .spawn(async move { incoming.next().await })
937            .detach();
938
939        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
940        let _request = server_conn.rx.next().await.unwrap().unwrap();
941
942        drop(server_conn);
943        assert_eq!(
944            response.await.unwrap_err().to_string(),
945            "connection was closed"
946        );
947    }
948}