1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25use tracing::instrument;
26
/// Identifier for a single connection registered with a [`Peer`].
///
/// Ids are allocated from the peer's monotonically increasing counter, so
/// they are unique per `Peer` instance (not globally).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize)]
pub struct ConnectionId(pub u32);
29
30impl fmt::Display for ConnectionId {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 self.0.fmt(f)
33 }
34}
35
/// Identifier for a remote peer, carried inside forwarded envelopes so the
/// final recipient can tell who originally sent a message.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
38
39impl fmt::Display for PeerId {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 self.0.fmt(f)
42 }
43}
44
/// Proof that a request of type `T` was received on a connection; required by
/// [`Peer::respond`] so a response can only be sent for a request that was
/// actually received.
pub struct Receipt<T> {
    // Connection the request arrived on; the response is routed back here.
    pub sender_id: ConnectionId,
    // Id of the request message this receipt acknowledges.
    pub message_id: u32,
    // Ties the receipt to the request's payload type without storing a `T`.
    payload_type: PhantomData<T>,
}
50
51impl<T> Clone for Receipt<T> {
52 fn clone(&self) -> Self {
53 Self {
54 sender_id: self.sender_id,
55 message_id: self.message_id,
56 payload_type: PhantomData,
57 }
58 }
59}
60
// `Receipt` is plain-old-data (two integers plus a zero-sized marker), so it
// is `Copy` for every `T` — no `T: Copy` bound is needed since no `T` value
// is stored.
impl<T> Copy for Receipt<T> {}
62
/// A decoded incoming message of a concrete payload type `T`, together with
/// the routing metadata from its wire envelope.
pub struct TypedEnvelope<T> {
    // Connection this envelope arrived on.
    pub sender_id: ConnectionId,
    // Set when the message was forwarded on behalf of another peer.
    pub original_sender_id: Option<PeerId>,
    // Per-connection message id assigned by the sender.
    pub message_id: u32,
    pub payload: T,
}
69
70impl<T> TypedEnvelope<T> {
71 pub fn original_sender_id(&self) -> Result<PeerId> {
72 self.original_sender_id
73 .ok_or_else(|| anyhow!("missing original_sender_id"))
74 }
75}
76
77impl<T: RequestMessage> TypedEnvelope<T> {
78 pub fn receipt(&self) -> Receipt<T> {
79 Receipt {
80 sender_id: self.sender_id,
81 message_id: self.message_id,
82 payload_type: PhantomData,
83 }
84 }
85}
86
/// An RPC endpoint that multiplexes any number of connections, matching
/// responses to in-flight requests and assigning per-connection message ids.
pub struct Peer {
    // State for every live connection, keyed by its id. Entries are removed
    // on `disconnect`/`reset` or when a connection's io future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of unique `ConnectionId`s for this peer.
    next_connection_id: AtomicU32,
}
91
/// Per-connection bookkeeping shared between the `Peer` and the connection's
/// io future. Cheap to clone: all fields are handles.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    // Queue of messages awaiting transmission by the io future. Unbounded so
    // senders never have to yield.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    // Next message id to assign on this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending requests awaiting a response, keyed by request message id. Each
    // entry also carries a "barrier" sender used to resume the io side only
    // after the requester has observed the response. The outer `Option` is
    // taken (set to `None`) when the connection closes, failing all waiters.
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
101
// Send a ping if nothing has been written for this long, so idle connections
// are kept alive.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
// Abort the connection if a single write (or delivery to the incoming
// channel) takes longer than this.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
// Abort the connection if nothing is received for this long. Public so the
// remote side can size its keepalive interval safely below it.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
105
impl Peer {
    /// Creates a peer with no connections. Connection ids are allocated from
    /// a counter starting at zero (unique per peer, not globally).
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

    /// Registers `connection` with this peer and returns:
    /// - the id assigned to the connection,
    /// - an io future that drives all reads, writes, keepalives, and timeouts
    ///   for the connection (the caller must spawn it; when it resolves the
    ///   connection is dead and its state has been deregistered),
    /// - a stream of incoming non-response messages decoded into typed
    ///   envelopes (responses are routed to their pending requests instead).
    ///
    /// `create_timer` supplies sleep futures so callers control the clock
    /// (wall clock in production, a deterministic executor in tests).
    #[instrument(skip_all)]
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::debug!(%connection_id, "handle io future: start");

            // On every exit path (error, clean shutdown, or drop) tear the
            // connection down: taking the response-channel map drops all
            // pending request senders (failing their awaiting futures), and
            // the connection entry is removed from the peer.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::debug!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::debug!(%connection_id, "outer loop iteration start");
                // One read future per outer iteration; the inner loop keeps
                // polling it alongside writes and timers until a complete
                // message arrives, then breaks to start a fresh read.
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::debug!(%connection_id, "inner loop iteration start");
                    // `select_biased!` gives writes priority over the
                    // keepalive timer, which in turn preempts reads.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
                                // Race the write against a timeout so a
                                // stalled socket can't wedge the io loop.
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
                                        // Any successful write counts as
                                        // activity; no ping needed yet.
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            None => {
                                // All senders dropped (e.g. `disconnect`):
                                // treat as a clean shutdown.
                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::debug!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::debug!(%connection_id, "incoming rpc message: received");
                            tracing::debug!(%connection_id, "receive timeout: resetting");
                            // Any traffic (including pings) counts as liveness.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::debug!(%connection_id, "incoming rpc message: processing");
                                // Forward into the bounded incoming channel;
                                // if the consumer is too slow we time out
                                // rather than block the io loop forever.
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // Restart the outer loop to begin a new read.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Post-process the raw incoming envelopes: route responses to their
        // pending requests (yielding nothing downstream) and decode everything
        // else into typed envelopes for the returned stream.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::debug!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message future: end"
                    );
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `as_mut()?` bails out if the connection already closed
                    // and the channel map was taken.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Barrier: don't process further incoming messages
                        // until the requester has observed its response, so
                        // responses are ordered relative to later messages.
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    // Responses never surface on the public incoming stream.
                    None
                } else {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message: received"
                    );
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

    /// Test-only variant of [`Peer::add_connection`] that drives timers with
    /// the given background executor.
    #[cfg(any(test, feature = "test-support"))]
    pub async fn add_test_connection(
        self: &Arc<Self>,
        connection: Connection,
        executor: Arc<gpui::executor::Background>,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        let executor = executor.clone();
        self.add_connection(connection, move |duration| executor.timer(duration))
            .await
    }

    /// Drops the connection's state. This closes its outgoing channel, which
    /// causes the connection's io future to finish cleanly.
    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    /// Disconnects every connection registered with this peer.
    pub fn reset(&self) {
        self.connections.write().clear();
    }

    /// Sends `request` to `receiver_id` and resolves with the typed response
    /// (or an error if the connection closes or the peer replies with an
    /// error payload).
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

    /// Like [`Peer::request`], but records `sender_id` as the original sender
    /// so the receiver knows on whose behalf the request was forwarded.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

    /// Shared implementation of [`Peer::request`]/[`Peer::forward_request`].
    ///
    /// The response channel is registered *before* the request is enqueued so
    /// a fast response can never race past its waiter. The send itself happens
    /// synchronously; only awaiting the response is deferred to the returned
    /// future.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                )))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // Dropping `_barrier` resumes the connection's incoming-message
            // loop (see the barrier in `add_connection`).
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

    /// Sends a one-way message to `receiver_id` (no response expected).
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(proto::Message::Envelope(
                message.into_envelope(message_id, None, None),
            ))?;
        Ok(())
    }

    /// Sends a one-way message to `receiver_id`, recording `sender_id` as the
    /// original sender of the forwarded message.
    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(proto::Message::Envelope(message.into_envelope(
                message_id,
                None,
                Some(sender_id.0),
            )))?;
        Ok(())
    }

    /// Sends `response` back to the requester identified by `receipt`.
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(proto::Message::Envelope(response.into_envelope(
                message_id,
                Some(receipt.message_id),
                None,
            )))?;
        Ok(())
    }

    /// Fails the request identified by `receipt` with an error payload; the
    /// requester's future resolves to `Err` with the error's message.
    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(proto::Message::Envelope(response.into_envelope(
                message_id,
                Some(receipt.message_id),
                None,
            )))?;
        Ok(())
    }

    /// Looks up the state for `connection_id`, returning a cheap handle-clone
    /// or an error if the connection is not (or no longer) registered.
    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}
483
484impl Serialize for Peer {
485 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
486 where
487 S: serde::Serializer,
488 {
489 let mut state = serializer.serialize_struct("Peer", 2)?;
490 state.serialize_field("connections", &*self.connections.read())?;
491 state.end()
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Two clients talk to one server; each client's requests get the
    /// expected typed responses, independently of the other client.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Fix: disconnect client2's own connection id. The previous code
        // passed `client1_conn_id` here, which only worked by accident
        // because both peers allocate ids starting from zero.
        client2.disconnect(client2_conn_id);

        /// Responds to `Ping` with `Ack` and echoes `Test` payloads back.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// A response must be observed *after* any one-way messages the server
    /// sent before responding — exercising the barrier in `add_connection`.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Two one-way messages, then the response to the request.
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not disturb the ordering or
    /// delivery guarantees of a second, still-live request.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` must end both the io future and the incoming stream, and
    /// the remote side's socket must subsequently fail to send.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// Dropping the remote end of the connection must fail any request that
    /// is still awaiting its response.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}