peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use std::{fmt, sync::atomic::Ordering::SeqCst};
 15use std::{
 16    future::Future,
 17    marker::PhantomData,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22    time::Duration,
 23};
 24use tracing::instrument;
 25
 /// Identifier of a connection registered with a `Peer`.
 ///
 /// `Peer::add_connection` fills `owner_id` from the peer's current epoch and `id` from the
 /// peer's connection counter. It is displayed as `owner_id/id`.
 26#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
 27pub struct ConnectionId {
     /// Epoch of the owning `Peer` at the time the connection was added.
 28    pub owner_id: u32,
     /// Value of the owning peer's connection counter; `Peer::reset` restarts it at 0.
 29    pub id: u32,
 30}
 31
 32impl Into<PeerId> for ConnectionId {
 33    fn into(self) -> PeerId {
 34        PeerId {
 35            owner_id: self.owner_id,
 36            id: self.id,
 37        }
 38    }
 39}
 40
 41impl From<PeerId> for ConnectionId {
 42    fn from(peer_id: PeerId) -> Self {
 43        Self {
 44            owner_id: peer_id.owner_id,
 45            id: peer_id.id,
 46        }
 47    }
 48}
 49
 50impl fmt::Display for ConnectionId {
 51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 52        write!(f, "{}/{}", self.owner_id, self.id)
 53    }
 54}
 55
 /// Handle for answering a received request of type `T` through `Peer::respond` or
 /// `Peer::respond_with_error`.
 56pub struct Receipt<T> {
     /// Connection the request arrived on; the response is sent back over it.
 57    pub sender_id: ConnectionId,
     /// Message id of the request, passed as the `responding_to` id of the response envelope.
 58    pub message_id: u32,
     /// Ties the receipt to the request type, so `Peer::respond` accepts only `T::Response`.
 59    payload_type: PhantomData<T>,
 60}
 61
 62impl<T> Clone for Receipt<T> {
 63    fn clone(&self) -> Self {
 64        Self {
 65            sender_id: self.sender_id,
 66            message_id: self.message_id,
 67            payload_type: PhantomData,
 68        }
 69    }
 70}
 71
 72impl<T> Copy for Receipt<T> {}
 73
 /// A received message of type `T` together with its routing metadata.
 74pub struct TypedEnvelope<T> {
     /// Connection the message arrived on.
 75    pub sender_id: ConnectionId,
     /// Original sender when the message was forwarded on someone else's behalf. Presumably
     /// populated from the `original_sender_id` that `Peer::forward_send` /
     /// `Peer::forward_request` put on the envelope — TODO confirm in `build_typed_envelope`.
 76    pub original_sender_id: Option<PeerId>,
     /// Id the sending side assigned to this message.
 77    pub message_id: u32,
     /// The decoded message.
 78    pub payload: T,
 79}
 80
 81impl<T> TypedEnvelope<T> {
 82    pub fn original_sender_id(&self) -> Result<PeerId> {
 83        self.original_sender_id
 84            .ok_or_else(|| anyhow!("missing original_sender_id"))
 85    }
 86}
 87
 88impl<T: RequestMessage> TypedEnvelope<T> {
 89    pub fn receipt(&self) -> Receipt<T> {
 90        Receipt {
 91            sender_id: self.sender_id,
 92            message_id: self.message_id,
 93            payload_type: PhantomData,
 94        }
 95    }
 96}
 97
 /// Manages a set of RPC connections and the per-connection state used to send messages,
 /// issue requests, and route their responses.
 98pub struct Peer {
     /// Stamped as `owner_id` onto each new `ConnectionId`; replaced by `Peer::reset`.
 99    epoch: AtomicU32,
     /// Registered connections. An entry is removed when its I/O future ends, or through
     /// `disconnect`, `teardown`, or `reset`.
100    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
     /// Source of the `id` half of each new `ConnectionId`; set back to 0 by `Peer::reset`.
101    next_connection_id: AtomicU32,
102}
103
/// Per-connection bookkeeping shared between `Peer`'s methods and the connection's I/O
/// future. Every field is a channel sender or an `Arc`, so clones share the same state.
/// Only `next_message_id` is serialized.
104#[derive(Clone, Serialize)]
105pub struct ConnectionState {
    /// Queue of messages the I/O future writes to the connection.
106    #[serde(skip)]
107    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Counter that assigns an id to every outgoing envelope on this connection.
108    next_message_id: Arc<AtomicU32>,
    /// Senders for requests awaiting a response, keyed by the request's message id. A
    /// response is delivered with a barrier sender that the requester drops once it has
    /// handled the response. Becomes `None` when the connection's I/O future finishes, which
    /// drops all pending senders.
109    #[allow(clippy::type_complexity)]
110    #[serde(skip)]
111    response_channels:
112        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
113}
114
/// How long a connection's I/O future goes without writing before it sends a `Ping`.
115const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Upper bound on a single socket write, and on handing a received envelope to the incoming
/// stream, before the I/O future fails the connection.
116const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum gap between received messages before the I/O future fails the connection.
117pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
118
119impl Peer {
120    pub fn new(epoch: u32) -> Arc<Self> {
121        Arc::new(Self {
122            epoch: AtomicU32::new(epoch),
123            connections: Default::default(),
124            next_connection_id: Default::default(),
125        })
126    }
127
128    pub fn epoch(&self) -> u32 {
129        self.epoch.load(SeqCst)
130    }
131
132    #[instrument(skip_all)]
133    pub fn add_connection<F, Fut, Out>(
134        self: &Arc<Self>,
135        connection: Connection,
136        create_timer: F,
137    ) -> (
138        ConnectionId,
139        impl Future<Output = anyhow::Result<()>> + Send,
140        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
141    )
142    where
143        F: Send + Fn(Duration) -> Fut,
144        Fut: Send + Future<Output = Out>,
145        Out: Send,
146    {
147        // For outgoing messages, use an unbounded channel so that application code
148        // can always send messages without yielding. For incoming messages, use a
149        // bounded channel so that other peers will receive backpressure if they send
150        // messages faster than this peer can process them.
151        #[cfg(any(test, feature = "test-support"))]
152        const INCOMING_BUFFER_SIZE: usize = 1;
153        #[cfg(not(any(test, feature = "test-support")))]
154        const INCOMING_BUFFER_SIZE: usize = 64;
155        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
156        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
157
158        let connection_id = ConnectionId {
159            owner_id: self.epoch.load(SeqCst),
160            id: self.next_connection_id.fetch_add(1, SeqCst),
161        };
162        let connection_state = ConnectionState {
163            outgoing_tx,
164            next_message_id: Default::default(),
165            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
166        };
167        let mut writer = MessageStream::new(connection.tx);
168        let mut reader = MessageStream::new(connection.rx);
169
170        let this = self.clone();
171        let response_channels = connection_state.response_channels.clone();
172        let handle_io = async move {
173            tracing::debug!(%connection_id, "handle io future: start");
174
175            let _end_connection = util::defer(|| {
176                response_channels.lock().take();
177                this.connections.write().remove(&connection_id);
178                tracing::debug!(%connection_id, "handle io future: end");
179            });
180
181            // Send messages on this frequency so the connection isn't closed.
182            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
183            futures::pin_mut!(keepalive_timer);
184
185            // Disconnect if we don't receive messages at least this frequently.
186            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
187            futures::pin_mut!(receive_timeout);
188
189            loop {
190                tracing::debug!(%connection_id, "outer loop iteration start");
191                let read_message = reader.read().fuse();
192                futures::pin_mut!(read_message);
193
194                loop {
195                    tracing::debug!(%connection_id, "inner loop iteration start");
196                    futures::select_biased! {
197                        outgoing = outgoing_rx.next().fuse() => match outgoing {
198                            Some(outgoing) => {
199                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
200                                futures::select_biased! {
201                                    result = writer.write(outgoing).fuse() => {
202                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
203                                        result.context("failed to write RPC message")?;
204                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
205                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
206                                    }
207                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
208                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
209                                        Err(anyhow!("timed out writing message"))?;
210                                    }
211                                }
212                            }
213                            None => {
214                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
215                                return Ok(())
216                            },
217                        },
218                        _ = keepalive_timer => {
219                            tracing::debug!(%connection_id, "keepalive interval: pinging");
220                            futures::select_biased! {
221                                result = writer.write(proto::Message::Ping).fuse() => {
222                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
223                                    result.context("failed to send keepalive")?;
224                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
225                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
226                                }
227                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
228                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
229                                    Err(anyhow!("timed out sending keepalive"))?;
230                                }
231                            }
232                        }
233                        incoming = read_message => {
234                            let incoming = incoming.context("error reading rpc message from socket")?;
235                            tracing::debug!(%connection_id, "incoming rpc message: received");
236                            tracing::debug!(%connection_id, "receive timeout: resetting");
237                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
238                            if let proto::Message::Envelope(incoming) = incoming {
239                                tracing::debug!(%connection_id, "incoming rpc message: processing");
240                                futures::select_biased! {
241                                    result = incoming_tx.send(incoming).fuse() => match result {
242                                        Ok(_) => {
243                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
244                                        }
245                                        Err(_) => {
246                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
247                                            return Ok(())
248                                        }
249                                    },
250                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
251                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
252                                        Err(anyhow!("timed out processing incoming message"))?
253                                    }
254                                }
255                            }
256                            break;
257                        },
258                        _ = receive_timeout => {
259                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
260                            Err(anyhow!("delay between messages too long"))?
261                        }
262                    }
263                }
264            }
265        };
266
267        let response_channels = connection_state.response_channels.clone();
268        self.connections
269            .write()
270            .insert(connection_id, connection_state);
271
272        let incoming_rx = incoming_rx.filter_map(move |incoming| {
273            let response_channels = response_channels.clone();
274            async move {
275                let message_id = incoming.id;
276                tracing::debug!(?incoming, "incoming message future: start");
277                let _end = util::defer(move || {
278                    tracing::debug!(%connection_id, message_id, "incoming message future: end");
279                });
280
281                if let Some(responding_to) = incoming.responding_to {
282                    tracing::debug!(
283                        %connection_id,
284                        message_id,
285                        responding_to,
286                        "incoming response: received"
287                    );
288                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
289                    if let Some(tx) = channel {
290                        let requester_resumed = oneshot::channel();
291                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
292                            tracing::debug!(
293                                %connection_id,
294                                message_id,
295                                responding_to = responding_to,
296                                ?error,
297                                "incoming response: request future dropped",
298                            );
299                        }
300
301                        tracing::debug!(
302                            %connection_id,
303                            message_id,
304                            responding_to,
305                            "incoming response: waiting to resume requester"
306                        );
307                        let _ = requester_resumed.1.await;
308                        tracing::debug!(
309                            %connection_id,
310                            message_id,
311                            responding_to,
312                            "incoming response: requester resumed"
313                        );
314                    } else {
315                        tracing::warn!(
316                            %connection_id,
317                            message_id,
318                            responding_to,
319                            "incoming response: unknown request"
320                        );
321                    }
322
323                    None
324                } else {
325                    tracing::debug!(%connection_id, message_id, "incoming message: received");
326                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
327                        tracing::error!(
328                            %connection_id,
329                            message_id,
330                            "unable to construct a typed envelope"
331                        );
332                        None
333                    })
334                }
335            }
336        });
337        (connection_id, handle_io, incoming_rx.boxed())
338    }
339
340    #[cfg(any(test, feature = "test-support"))]
341    pub fn add_test_connection(
342        self: &Arc<Self>,
343        connection: Connection,
344        executor: Arc<gpui::executor::Background>,
345    ) -> (
346        ConnectionId,
347        impl Future<Output = anyhow::Result<()>> + Send,
348        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
349    ) {
350        let executor = executor.clone();
351        self.add_connection(connection, move |duration| executor.timer(duration))
352    }
353
354    pub fn disconnect(&self, connection_id: ConnectionId) {
355        self.connections.write().remove(&connection_id);
356    }
357
358    pub fn reset(&self, epoch: u32) {
359        self.teardown();
360        self.next_connection_id.store(0, SeqCst);
361        self.epoch.store(epoch, SeqCst);
362    }
363
364    pub fn teardown(&self) {
365        self.connections.write().clear();
366    }
367
368    pub fn request<T: RequestMessage>(
369        &self,
370        receiver_id: ConnectionId,
371        request: T,
372    ) -> impl Future<Output = Result<T::Response>> {
373        self.request_internal(None, receiver_id, request)
374    }
375
376    pub fn forward_request<T: RequestMessage>(
377        &self,
378        sender_id: ConnectionId,
379        receiver_id: ConnectionId,
380        request: T,
381    ) -> impl Future<Output = Result<T::Response>> {
382        self.request_internal(Some(sender_id), receiver_id, request)
383    }
384
385    pub fn request_internal<T: RequestMessage>(
386        &self,
387        original_sender_id: Option<ConnectionId>,
388        receiver_id: ConnectionId,
389        request: T,
390    ) -> impl Future<Output = Result<T::Response>> {
391        let (tx, rx) = oneshot::channel();
392        let send = self.connection_state(receiver_id).and_then(|connection| {
393            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
394            connection
395                .response_channels
396                .lock()
397                .as_mut()
398                .ok_or_else(|| anyhow!("connection was closed"))?
399                .insert(message_id, tx);
400            connection
401                .outgoing_tx
402                .unbounded_send(proto::Message::Envelope(request.into_envelope(
403                    message_id,
404                    None,
405                    original_sender_id.map(Into::into),
406                )))
407                .map_err(|_| anyhow!("connection was closed"))?;
408            Ok(())
409        });
410        async move {
411            send?;
412            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
413            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
414                Err(anyhow!(
415                    "RPC request {} failed - {}",
416                    T::NAME,
417                    error.message
418                ))
419            } else {
420                T::Response::from_envelope(response)
421                    .ok_or_else(|| anyhow!("received response of the wrong type"))
422            }
423        }
424    }
425
426    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
427        let connection = self.connection_state(receiver_id)?;
428        let message_id = connection
429            .next_message_id
430            .fetch_add(1, atomic::Ordering::SeqCst);
431        connection
432            .outgoing_tx
433            .unbounded_send(proto::Message::Envelope(
434                message.into_envelope(message_id, None, None),
435            ))?;
436        Ok(())
437    }
438
439    pub fn forward_send<T: EnvelopedMessage>(
440        &self,
441        sender_id: ConnectionId,
442        receiver_id: ConnectionId,
443        message: T,
444    ) -> Result<()> {
445        let connection = self.connection_state(receiver_id)?;
446        let message_id = connection
447            .next_message_id
448            .fetch_add(1, atomic::Ordering::SeqCst);
449        connection
450            .outgoing_tx
451            .unbounded_send(proto::Message::Envelope(message.into_envelope(
452                message_id,
453                None,
454                Some(sender_id.into()),
455            )))?;
456        Ok(())
457    }
458
459    pub fn respond<T: RequestMessage>(
460        &self,
461        receipt: Receipt<T>,
462        response: T::Response,
463    ) -> Result<()> {
464        let connection = self.connection_state(receipt.sender_id)?;
465        let message_id = connection
466            .next_message_id
467            .fetch_add(1, atomic::Ordering::SeqCst);
468        connection
469            .outgoing_tx
470            .unbounded_send(proto::Message::Envelope(response.into_envelope(
471                message_id,
472                Some(receipt.message_id),
473                None,
474            )))?;
475        Ok(())
476    }
477
478    pub fn respond_with_error<T: RequestMessage>(
479        &self,
480        receipt: Receipt<T>,
481        response: proto::Error,
482    ) -> Result<()> {
483        let connection = self.connection_state(receipt.sender_id)?;
484        let message_id = connection
485            .next_message_id
486            .fetch_add(1, atomic::Ordering::SeqCst);
487        connection
488            .outgoing_tx
489            .unbounded_send(proto::Message::Envelope(response.into_envelope(
490                message_id,
491                Some(receipt.message_id),
492                None,
493            )))?;
494        Ok(())
495    }
496
497    pub fn respond_with_unhandled_message(
498        &self,
499        envelope: Box<dyn AnyTypedEnvelope>,
500    ) -> Result<()> {
501        let connection = self.connection_state(envelope.sender_id())?;
502        let response = proto::Error {
503            message: format!("message {} was not handled", envelope.payload_type_name()),
504        };
505        let message_id = connection
506            .next_message_id
507            .fetch_add(1, atomic::Ordering::SeqCst);
508        connection
509            .outgoing_tx
510            .unbounded_send(proto::Message::Envelope(response.into_envelope(
511                message_id,
512                Some(envelope.message_id()),
513                None,
514            )))?;
515        Ok(())
516    }
517
518    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
519        let connections = self.connections.read();
520        let connection = connections
521            .get(&connection_id)
522            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
523        Ok(connection.clone())
524    }
525}
526
527impl Serialize for Peer {
528    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
529    where
530        S: serde::Serializer,
531    {
532        let mut state = serializer.serialize_struct("Peer", 2)?;
533        state.serialize_field("connections", &*self.connections.read())?;
534        state.end()
535    }
536}
537
538#[cfg(test)]
539mod tests {
540    use super::*;
541    use crate::TypedEnvelope;
542    use async_tungstenite::tungstenite::Message as WebSocketMessage;
543    use gpui::TestAppContext;
544
545    #[ctor::ctor]
546    fn init_logger() {
547        if std::env::var("RUST_LOG").is_ok() {
548            env_logger::init();
549        }
550    }
551
552    #[gpui::test(iterations = 50)]
553    async fn test_request_response(cx: &mut TestAppContext) {
554        let executor = cx.foreground();
555
556        // create 2 clients connected to 1 server
557        let server = Peer::new(0);
558        let client1 = Peer::new(0);
559        let client2 = Peer::new(0);
560
561        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
562            Connection::in_memory(cx.background());
563        let (client1_conn_id, io_task1, client1_incoming) =
564            client1.add_test_connection(client1_to_server_conn, cx.background());
565        let (_, io_task2, server_incoming1) =
566            server.add_test_connection(server_to_client_1_conn, cx.background());
567
568        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
569            Connection::in_memory(cx.background());
570        let (client2_conn_id, io_task3, client2_incoming) =
571            client2.add_test_connection(client2_to_server_conn, cx.background());
572        let (_, io_task4, server_incoming2) =
573            server.add_test_connection(server_to_client_2_conn, cx.background());
574
575        executor.spawn(io_task1).detach();
576        executor.spawn(io_task2).detach();
577        executor.spawn(io_task3).detach();
578        executor.spawn(io_task4).detach();
579        executor
580            .spawn(handle_messages(server_incoming1, server.clone()))
581            .detach();
582        executor
583            .spawn(handle_messages(client1_incoming, client1.clone()))
584            .detach();
585        executor
586            .spawn(handle_messages(server_incoming2, server.clone()))
587            .detach();
588        executor
589            .spawn(handle_messages(client2_incoming, client2.clone()))
590            .detach();
591
592        assert_eq!(
593            client1
594                .request(client1_conn_id, proto::Ping {},)
595                .await
596                .unwrap(),
597            proto::Ack {}
598        );
599
600        assert_eq!(
601            client2
602                .request(client2_conn_id, proto::Ping {},)
603                .await
604                .unwrap(),
605            proto::Ack {}
606        );
607
608        assert_eq!(
609            client1
610                .request(client1_conn_id, proto::Test { id: 1 },)
611                .await
612                .unwrap(),
613            proto::Test { id: 1 }
614        );
615
616        assert_eq!(
617            client2
618                .request(client2_conn_id, proto::Test { id: 2 })
619                .await
620                .unwrap(),
621            proto::Test { id: 2 }
622        );
623
624        client1.disconnect(client1_conn_id);
625        client2.disconnect(client1_conn_id);
626
627        async fn handle_messages(
628            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
629            peer: Arc<Peer>,
630        ) -> Result<()> {
631            while let Some(envelope) = messages.next().await {
632                let envelope = envelope.into_any();
633                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
634                    let receipt = envelope.receipt();
635                    peer.respond(receipt, proto::Ack {})?
636                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
637                {
638                    peer.respond(envelope.receipt(), envelope.payload.clone())?
639                } else {
640                    panic!("unknown message type");
641                }
642            }
643
644            Ok(())
645        }
646    }
647
648    #[gpui::test(iterations = 50)]
649    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
650        let executor = cx.foreground();
651        let server = Peer::new(0);
652        let client = Peer::new(0);
653
654        let (client_to_server_conn, server_to_client_conn, _kill) =
655            Connection::in_memory(cx.background());
656        let (client_to_server_conn_id, io_task1, mut client_incoming) =
657            client.add_test_connection(client_to_server_conn, cx.background());
658        let (server_to_client_conn_id, io_task2, mut server_incoming) =
659            server.add_test_connection(server_to_client_conn, cx.background());
660
661        executor.spawn(io_task1).detach();
662        executor.spawn(io_task2).detach();
663
664        executor
665            .spawn(async move {
666                let request = server_incoming
667                    .next()
668                    .await
669                    .unwrap()
670                    .into_any()
671                    .downcast::<TypedEnvelope<proto::Ping>>()
672                    .unwrap();
673
674                server
675                    .send(
676                        server_to_client_conn_id,
677                        proto::Error {
678                            message: "message 1".to_string(),
679                        },
680                    )
681                    .unwrap();
682                server
683                    .send(
684                        server_to_client_conn_id,
685                        proto::Error {
686                            message: "message 2".to_string(),
687                        },
688                    )
689                    .unwrap();
690                server.respond(request.receipt(), proto::Ack {}).unwrap();
691
692                // Prevent the connection from being dropped
693                server_incoming.next().await;
694            })
695            .detach();
696
697        let events = Arc::new(Mutex::new(Vec::new()));
698
699        let response = client.request(client_to_server_conn_id, proto::Ping {});
700        let response_task = executor.spawn({
701            let events = events.clone();
702            async move {
703                response.await.unwrap();
704                events.lock().push("response".to_string());
705            }
706        });
707
708        executor
709            .spawn({
710                let events = events.clone();
711                async move {
712                    let incoming1 = client_incoming
713                        .next()
714                        .await
715                        .unwrap()
716                        .into_any()
717                        .downcast::<TypedEnvelope<proto::Error>>()
718                        .unwrap();
719                    events.lock().push(incoming1.payload.message);
720                    let incoming2 = client_incoming
721                        .next()
722                        .await
723                        .unwrap()
724                        .into_any()
725                        .downcast::<TypedEnvelope<proto::Error>>()
726                        .unwrap();
727                    events.lock().push(incoming2.payload.message);
728
729                    // Prevent the connection from being dropped
730                    client_incoming.next().await;
731                }
732            })
733            .detach();
734
735        response_task.await;
736        assert_eq!(
737            &*events.lock(),
738            &[
739                "message 1".to_string(),
740                "message 2".to_string(),
741                "response".to_string()
742            ]
743        );
744    }
745
746    #[gpui::test(iterations = 50)]
747    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
748        let executor = cx.foreground();
749        let server = Peer::new(0);
750        let client = Peer::new(0);
751
752        let (client_to_server_conn, server_to_client_conn, _kill) =
753            Connection::in_memory(cx.background());
754        let (client_to_server_conn_id, io_task1, mut client_incoming) =
755            client.add_test_connection(client_to_server_conn, cx.background());
756        let (server_to_client_conn_id, io_task2, mut server_incoming) =
757            server.add_test_connection(server_to_client_conn, cx.background());
758
759        executor.spawn(io_task1).detach();
760        executor.spawn(io_task2).detach();
761
762        executor
763            .spawn(async move {
764                let request1 = server_incoming
765                    .next()
766                    .await
767                    .unwrap()
768                    .into_any()
769                    .downcast::<TypedEnvelope<proto::Ping>>()
770                    .unwrap();
771                let request2 = server_incoming
772                    .next()
773                    .await
774                    .unwrap()
775                    .into_any()
776                    .downcast::<TypedEnvelope<proto::Ping>>()
777                    .unwrap();
778
779                server
780                    .send(
781                        server_to_client_conn_id,
782                        proto::Error {
783                            message: "message 1".to_string(),
784                        },
785                    )
786                    .unwrap();
787                server
788                    .send(
789                        server_to_client_conn_id,
790                        proto::Error {
791                            message: "message 2".to_string(),
792                        },
793                    )
794                    .unwrap();
795                server.respond(request1.receipt(), proto::Ack {}).unwrap();
796                server.respond(request2.receipt(), proto::Ack {}).unwrap();
797
798                // Prevent the connection from being dropped
799                server_incoming.next().await;
800            })
801            .detach();
802
803        let events = Arc::new(Mutex::new(Vec::new()));
804
805        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
806        let request1_task = executor.spawn(request1);
807        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
808        let request2_task = executor.spawn({
809            let events = events.clone();
810            async move {
811                request2.await.unwrap();
812                events.lock().push("response 2".to_string());
813            }
814        });
815
816        executor
817            .spawn({
818                let events = events.clone();
819                async move {
820                    let incoming1 = client_incoming
821                        .next()
822                        .await
823                        .unwrap()
824                        .into_any()
825                        .downcast::<TypedEnvelope<proto::Error>>()
826                        .unwrap();
827                    events.lock().push(incoming1.payload.message);
828                    let incoming2 = client_incoming
829                        .next()
830                        .await
831                        .unwrap()
832                        .into_any()
833                        .downcast::<TypedEnvelope<proto::Error>>()
834                        .unwrap();
835                    events.lock().push(incoming2.payload.message);
836
837                    // Prevent the connection from being dropped
838                    client_incoming.next().await;
839                }
840            })
841            .detach();
842
843        // Allow the request to make some progress before dropping it.
844        cx.background().simulate_random_delay().await;
845        drop(request1_task);
846
847        request2_task.await;
848        assert_eq!(
849            &*events.lock(),
850            &[
851                "message 1".to_string(),
852                "message 2".to_string(),
853                "response 2".to_string()
854            ]
855        );
856    }
857
858    #[gpui::test(iterations = 50)]
859    async fn test_disconnect(cx: &mut TestAppContext) {
860        let executor = cx.foreground();
861
862        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
863
864        let client = Peer::new(0);
865        let (connection_id, io_handler, mut incoming) =
866            client.add_test_connection(client_conn, cx.background());
867
868        let (io_ended_tx, io_ended_rx) = oneshot::channel();
869        executor
870            .spawn(async move {
871                io_handler.await.ok();
872                io_ended_tx.send(()).unwrap();
873            })
874            .detach();
875
876        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
877        executor
878            .spawn(async move {
879                incoming.next().await;
880                messages_ended_tx.send(()).unwrap();
881            })
882            .detach();
883
884        client.disconnect(connection_id);
885
886        let _ = io_ended_rx.await;
887        let _ = messages_ended_rx.await;
888        assert!(server_conn
889            .send(WebSocketMessage::Binary(vec![]))
890            .await
891            .is_err());
892    }
893
894    #[gpui::test(iterations = 50)]
895    async fn test_io_error(cx: &mut TestAppContext) {
896        let executor = cx.foreground();
897        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
898
899        let client = Peer::new(0);
900        let (connection_id, io_handler, mut incoming) =
901            client.add_test_connection(client_conn, cx.background());
902        executor.spawn(io_handler).detach();
903        executor
904            .spawn(async move { incoming.next().await })
905            .detach();
906
907        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
908        let _request = server_conn.rx.next().await.unwrap().unwrap();
909
910        drop(server_conn);
911        assert_eq!(
912            response.await.unwrap_err().to_string(),
913            "connection was closed"
914        );
915    }
916}