peer.rs

  1use super::{
  2    proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
  3    Connection,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use collections::HashMap;
  7use futures::{
  8    channel::{mpsc, oneshot},
  9    stream::BoxStream,
 10    FutureExt, SinkExt, StreamExt,
 11};
 12use parking_lot::{Mutex, RwLock};
 13use serde::{ser::SerializeStruct, Serialize};
 14use std::sync::atomic::Ordering::SeqCst;
 15use std::{
 16    fmt,
 17    future::Future,
 18    marker::PhantomData,
 19    sync::{
 20        atomic::{self, AtomicU32},
 21        Arc,
 22    },
 23    time::Duration,
 24};
 25use tracing::instrument;
 26
 27#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
 28pub struct ConnectionId(pub u32);
 29
 30impl fmt::Display for ConnectionId {
 31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 32        self.0.fmt(f)
 33    }
 34}
 35
 36#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 37pub struct PeerId(pub u32);
 38
 39impl fmt::Display for PeerId {
 40    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 41        self.0.fmt(f)
 42    }
 43}
 44
 45pub struct Receipt<T> {
 46    pub sender_id: ConnectionId,
 47    pub message_id: u32,
 48    payload_type: PhantomData<T>,
 49}
 50
 51impl<T> Clone for Receipt<T> {
 52    fn clone(&self) -> Self {
 53        Self {
 54            sender_id: self.sender_id,
 55            message_id: self.message_id,
 56            payload_type: PhantomData,
 57        }
 58    }
 59}
 60
 61impl<T> Copy for Receipt<T> {}
 62
 63pub struct TypedEnvelope<T> {
 64    pub sender_id: ConnectionId,
 65    pub original_sender_id: Option<PeerId>,
 66    pub message_id: u32,
 67    pub payload: T,
 68}
 69
 70impl<T> TypedEnvelope<T> {
 71    pub fn original_sender_id(&self) -> Result<PeerId> {
 72        self.original_sender_id
 73            .ok_or_else(|| anyhow!("missing original_sender_id"))
 74    }
 75}
 76
 77impl<T: RequestMessage> TypedEnvelope<T> {
 78    pub fn receipt(&self) -> Receipt<T> {
 79        Receipt {
 80            sender_id: self.sender_id,
 81            message_id: self.message_id,
 82            payload_type: PhantomData,
 83        }
 84    }
 85}
 86
 87pub struct Peer {
 88    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 89    next_connection_id: AtomicU32,
 90}
 91
 92#[derive(Clone, Serialize)]
 93pub struct ConnectionState {
 94    #[serde(skip)]
 95    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
 96    next_message_id: Arc<AtomicU32>,
 97    #[allow(clippy::type_complexity)]
 98    #[serde(skip)]
 99    response_channels:
100        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
101}
102
/// If nothing has been written for this long, a `Ping` is sent to keep the
/// connection open.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Maximum time allowed for one write, or for handing one incoming message to
/// the application, before the connection is failed.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection is failed if no message is received for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
106
107impl Peer {
108    pub fn new() -> Arc<Self> {
109        Arc::new(Self {
110            connections: Default::default(),
111            next_connection_id: Default::default(),
112        })
113    }
114
115    #[instrument(skip_all)]
116    pub async fn add_connection<F, Fut, Out>(
117        self: &Arc<Self>,
118        connection: Connection,
119        create_timer: F,
120    ) -> (
121        ConnectionId,
122        impl Future<Output = anyhow::Result<()>> + Send,
123        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
124    )
125    where
126        F: Send + Fn(Duration) -> Fut,
127        Fut: Send + Future<Output = Out>,
128        Out: Send,
129    {
130        // For outgoing messages, use an unbounded channel so that application code
131        // can always send messages without yielding. For incoming messages, use a
132        // bounded channel so that other peers will receive backpressure if they send
133        // messages faster than this peer can process them.
134        #[cfg(any(test, feature = "test-support"))]
135        const INCOMING_BUFFER_SIZE: usize = 1;
136        #[cfg(not(any(test, feature = "test-support")))]
137        const INCOMING_BUFFER_SIZE: usize = 64;
138        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
139        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
140
141        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
142        let connection_state = ConnectionState {
143            outgoing_tx,
144            next_message_id: Default::default(),
145            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
146        };
147        let mut writer = MessageStream::new(connection.tx);
148        let mut reader = MessageStream::new(connection.rx);
149
150        let this = self.clone();
151        let response_channels = connection_state.response_channels.clone();
152        let handle_io = async move {
153            tracing::debug!(%connection_id, "handle io future: start");
154
155            let _end_connection = util::defer(|| {
156                response_channels.lock().take();
157                this.connections.write().remove(&connection_id);
158                tracing::debug!(%connection_id, "handle io future: end");
159            });
160
161            // Send messages on this frequency so the connection isn't closed.
162            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
163            futures::pin_mut!(keepalive_timer);
164
165            // Disconnect if we don't receive messages at least this frequently.
166            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
167            futures::pin_mut!(receive_timeout);
168
169            loop {
170                tracing::debug!(%connection_id, "outer loop iteration start");
171                let read_message = reader.read().fuse();
172                futures::pin_mut!(read_message);
173
174                loop {
175                    tracing::debug!(%connection_id, "inner loop iteration start");
176                    futures::select_biased! {
177                        outgoing = outgoing_rx.next().fuse() => match outgoing {
178                            Some(outgoing) => {
179                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
180                                futures::select_biased! {
181                                    result = writer.write(outgoing).fuse() => {
182                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
183                                        result.context("failed to write RPC message")?;
184                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
185                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
186                                    }
187                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
188                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
189                                        Err(anyhow!("timed out writing message"))?;
190                                    }
191                                }
192                            }
193                            None => {
194                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
195                                return Ok(())
196                            },
197                        },
198                        _ = keepalive_timer => {
199                            tracing::debug!(%connection_id, "keepalive interval: pinging");
200                            futures::select_biased! {
201                                result = writer.write(proto::Message::Ping).fuse() => {
202                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
203                                    result.context("failed to send keepalive")?;
204                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
205                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
206                                }
207                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
208                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
209                                    Err(anyhow!("timed out sending keepalive"))?;
210                                }
211                            }
212                        }
213                        incoming = read_message => {
214                            let incoming = incoming.context("error reading rpc message from socket")?;
215                            tracing::debug!(%connection_id, "incoming rpc message: received");
216                            tracing::debug!(%connection_id, "receive timeout: resetting");
217                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
218                            if let proto::Message::Envelope(incoming) = incoming {
219                                tracing::debug!(%connection_id, "incoming rpc message: processing");
220                                futures::select_biased! {
221                                    result = incoming_tx.send(incoming).fuse() => match result {
222                                        Ok(_) => {
223                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
224                                        }
225                                        Err(_) => {
226                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
227                                            return Ok(())
228                                        }
229                                    },
230                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
231                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
232                                        Err(anyhow!("timed out processing incoming message"))?
233                                    }
234                                }
235                            }
236                            break;
237                        },
238                        _ = receive_timeout => {
239                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
240                            Err(anyhow!("delay between messages too long"))?
241                        }
242                    }
243                }
244            }
245        };
246
247        let response_channels = connection_state.response_channels.clone();
248        self.connections
249            .write()
250            .insert(connection_id, connection_state);
251
252        let incoming_rx = incoming_rx.filter_map(move |incoming| {
253            let response_channels = response_channels.clone();
254            async move {
255                let message_id = incoming.id;
256                tracing::debug!(?incoming, "incoming message future: start");
257                let _end = util::defer(move || {
258                    tracing::debug!(
259                        %connection_id,
260                        message_id,
261                        "incoming message future: end"
262                    );
263                });
264
265                if let Some(responding_to) = incoming.responding_to {
266                    tracing::debug!(
267                        %connection_id,
268                        message_id,
269                        responding_to,
270                        "incoming response: received"
271                    );
272                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
273                    if let Some(tx) = channel {
274                        let requester_resumed = oneshot::channel();
275                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
276                            tracing::debug!(
277                                %connection_id,
278                                message_id,
279                                responding_to = responding_to,
280                                ?error,
281                                "incoming response: request future dropped",
282                            );
283                        }
284
285                        tracing::debug!(
286                            %connection_id,
287                            message_id,
288                            responding_to,
289                            "incoming response: waiting to resume requester"
290                        );
291                        let _ = requester_resumed.1.await;
292                        tracing::debug!(
293                            %connection_id,
294                            message_id,
295                            responding_to,
296                            "incoming response: requester resumed"
297                        );
298                    } else {
299                        tracing::warn!(
300                            %connection_id,
301                            message_id,
302                            responding_to,
303                            "incoming response: unknown request"
304                        );
305                    }
306
307                    None
308                } else {
309                    tracing::debug!(
310                        %connection_id,
311                        message_id,
312                        "incoming message: received"
313                    );
314                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
315                        tracing::error!(
316                            %connection_id,
317                            message_id,
318                            "unable to construct a typed envelope"
319                        );
320                        None
321                    })
322                }
323            }
324        });
325        (connection_id, handle_io, incoming_rx.boxed())
326    }
327
328    #[cfg(any(test, feature = "test-support"))]
329    pub async fn add_test_connection(
330        self: &Arc<Self>,
331        connection: Connection,
332        executor: Arc<gpui::executor::Background>,
333    ) -> (
334        ConnectionId,
335        impl Future<Output = anyhow::Result<()>> + Send,
336        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
337    ) {
338        let executor = executor.clone();
339        self.add_connection(connection, move |duration| executor.timer(duration))
340            .await
341    }
342
343    pub fn disconnect(&self, connection_id: ConnectionId) {
344        self.connections.write().remove(&connection_id);
345    }
346
347    pub fn reset(&self) {
348        self.connections.write().clear();
349    }
350
351    pub fn request<T: RequestMessage>(
352        &self,
353        receiver_id: ConnectionId,
354        request: T,
355    ) -> impl Future<Output = Result<T::Response>> {
356        self.request_internal(None, receiver_id, request)
357    }
358
359    pub fn forward_request<T: RequestMessage>(
360        &self,
361        sender_id: ConnectionId,
362        receiver_id: ConnectionId,
363        request: T,
364    ) -> impl Future<Output = Result<T::Response>> {
365        self.request_internal(Some(sender_id), receiver_id, request)
366    }
367
368    pub fn request_internal<T: RequestMessage>(
369        &self,
370        original_sender_id: Option<ConnectionId>,
371        receiver_id: ConnectionId,
372        request: T,
373    ) -> impl Future<Output = Result<T::Response>> {
374        let (tx, rx) = oneshot::channel();
375        let send = self.connection_state(receiver_id).and_then(|connection| {
376            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
377            connection
378                .response_channels
379                .lock()
380                .as_mut()
381                .ok_or_else(|| anyhow!("connection was closed"))?
382                .insert(message_id, tx);
383            connection
384                .outgoing_tx
385                .unbounded_send(proto::Message::Envelope(request.into_envelope(
386                    message_id,
387                    None,
388                    original_sender_id.map(|id| id.0),
389                )))
390                .map_err(|_| anyhow!("connection was closed"))?;
391            Ok(())
392        });
393        async move {
394            send?;
395            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
396            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
397                Err(anyhow!(
398                    "RPC request {} failed - {}",
399                    T::NAME,
400                    error.message
401                ))
402            } else {
403                T::Response::from_envelope(response)
404                    .ok_or_else(|| anyhow!("received response of the wrong type"))
405            }
406        }
407    }
408
409    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
410        let connection = self.connection_state(receiver_id)?;
411        let message_id = connection
412            .next_message_id
413            .fetch_add(1, atomic::Ordering::SeqCst);
414        connection
415            .outgoing_tx
416            .unbounded_send(proto::Message::Envelope(
417                message.into_envelope(message_id, None, None),
418            ))?;
419        Ok(())
420    }
421
422    pub fn forward_send<T: EnvelopedMessage>(
423        &self,
424        sender_id: ConnectionId,
425        receiver_id: ConnectionId,
426        message: T,
427    ) -> Result<()> {
428        let connection = self.connection_state(receiver_id)?;
429        let message_id = connection
430            .next_message_id
431            .fetch_add(1, atomic::Ordering::SeqCst);
432        connection
433            .outgoing_tx
434            .unbounded_send(proto::Message::Envelope(message.into_envelope(
435                message_id,
436                None,
437                Some(sender_id.0),
438            )))?;
439        Ok(())
440    }
441
442    pub fn respond<T: RequestMessage>(
443        &self,
444        receipt: Receipt<T>,
445        response: T::Response,
446    ) -> Result<()> {
447        let connection = self.connection_state(receipt.sender_id)?;
448        let message_id = connection
449            .next_message_id
450            .fetch_add(1, atomic::Ordering::SeqCst);
451        connection
452            .outgoing_tx
453            .unbounded_send(proto::Message::Envelope(response.into_envelope(
454                message_id,
455                Some(receipt.message_id),
456                None,
457            )))?;
458        Ok(())
459    }
460
461    pub fn respond_with_error<T: RequestMessage>(
462        &self,
463        receipt: Receipt<T>,
464        response: proto::Error,
465    ) -> Result<()> {
466        let connection = self.connection_state(receipt.sender_id)?;
467        let message_id = connection
468            .next_message_id
469            .fetch_add(1, atomic::Ordering::SeqCst);
470        connection
471            .outgoing_tx
472            .unbounded_send(proto::Message::Envelope(response.into_envelope(
473                message_id,
474                Some(receipt.message_id),
475                None,
476            )))?;
477        Ok(())
478    }
479
480    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
481        let connections = self.connections.read();
482        let connection = connections
483            .get(&connection_id)
484            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
485        Ok(connection.clone())
486    }
487}
488
489impl Serialize for Peer {
490    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
491    where
492        S: serde::Serializer,
493    {
494        let mut state = serializer.serialize_struct("Peer", 2)?;
495        state.serialize_field("connections", &*self.connections.read())?;
496        state.end()
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Enables `env_logger` output for the test binary when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Two clients connected to one server each make requests. Every request
    /// must resolve to the response produced by the other side's handler.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        // The third value from `Connection::in_memory` is bound rather than
        // discarded, so it lives until the end of the test. Each connection gets
        // its own binding name.
        let (client1_to_server_conn, server_to_client_1_conn, _kill1) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill2) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client must disconnect with the id its own peer assigned. Ids are
        // per-peer, so both happen to be 0, which hid the earlier mix-up of
        // passing `client1_conn_id` here.
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on
        /// anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// The server sends two one-way messages and then the response to the
    /// client's request. The client must observe both messages before its
    /// request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not stall the incoming stream or
    /// prevent a second request from receiving its response.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// After `disconnect`, the I/O future and the incoming stream both end, and
    /// the remote side can no longer send on the connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// If the remote end goes away while a request is pending, the request
    /// fails with "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}