1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt, TryFutureExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::{fmt, sync::atomic::Ordering::SeqCst};
15use std::{
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24use tracing::instrument;
25
26#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
27pub struct ConnectionId {
28 pub owner_id: u32,
29 pub id: u32,
30}
31
32impl Into<PeerId> for ConnectionId {
33 fn into(self) -> PeerId {
34 PeerId {
35 owner_id: self.owner_id,
36 id: self.id,
37 }
38 }
39}
40
41impl From<PeerId> for ConnectionId {
42 fn from(peer_id: PeerId) -> Self {
43 Self {
44 owner_id: peer_id.owner_id,
45 id: peer_id.id,
46 }
47 }
48}
49
50impl fmt::Display for ConnectionId {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{}/{}", self.owner_id, self.id)
53 }
54}
55
56pub struct Receipt<T> {
57 pub sender_id: ConnectionId,
58 pub message_id: u32,
59 payload_type: PhantomData<T>,
60}
61
62impl<T> Clone for Receipt<T> {
63 fn clone(&self) -> Self {
64 Self {
65 sender_id: self.sender_id,
66 message_id: self.message_id,
67 payload_type: PhantomData,
68 }
69 }
70}
71
72impl<T> Copy for Receipt<T> {}
73
74#[derive(Clone, Debug)]
75pub struct TypedEnvelope<T> {
76 pub sender_id: ConnectionId,
77 pub original_sender_id: Option<PeerId>,
78 pub message_id: u32,
79 pub payload: T,
80}
81
82impl<T> TypedEnvelope<T> {
83 pub fn original_sender_id(&self) -> Result<PeerId> {
84 self.original_sender_id
85 .ok_or_else(|| anyhow!("missing original_sender_id"))
86 }
87}
88
89impl<T: RequestMessage> TypedEnvelope<T> {
90 pub fn receipt(&self) -> Receipt<T> {
91 Receipt {
92 sender_id: self.sender_id,
93 message_id: self.message_id,
94 payload_type: PhantomData,
95 }
96 }
97}
98
99pub struct Peer {
100 epoch: AtomicU32,
101 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
102 next_connection_id: AtomicU32,
103}
104
105#[derive(Clone, Serialize)]
106pub struct ConnectionState {
107 #[serde(skip)]
108 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
109 next_message_id: Arc<AtomicU32>,
110 #[allow(clippy::type_complexity)]
111 #[serde(skip)]
112 response_channels:
113 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
114}
115
/// Idle time after which the I/O loop writes a `Ping` so the connection stays open.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Maximum time a single write, or hand-off of an incoming message, may take.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Maximum gap between received messages before the connection is treated as dead.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
119
120impl Peer {
121 pub fn new(epoch: u32) -> Arc<Self> {
122 Arc::new(Self {
123 epoch: AtomicU32::new(epoch),
124 connections: Default::default(),
125 next_connection_id: Default::default(),
126 })
127 }
128
129 pub fn epoch(&self) -> u32 {
130 self.epoch.load(SeqCst)
131 }
132
133 #[instrument(skip_all)]
134 pub fn add_connection<F, Fut, Out>(
135 self: &Arc<Self>,
136 connection: Connection,
137 create_timer: F,
138 ) -> (
139 ConnectionId,
140 impl Future<Output = anyhow::Result<()>> + Send,
141 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
142 )
143 where
144 F: Send + Fn(Duration) -> Fut,
145 Fut: Send + Future<Output = Out>,
146 Out: Send,
147 {
148 // For outgoing messages, use an unbounded channel so that application code
149 // can always send messages without yielding. For incoming messages, use a
150 // bounded channel so that other peers will receive backpressure if they send
151 // messages faster than this peer can process them.
152 #[cfg(any(test, feature = "test-support"))]
153 const INCOMING_BUFFER_SIZE: usize = 1;
154 #[cfg(not(any(test, feature = "test-support")))]
155 const INCOMING_BUFFER_SIZE: usize = 64;
156 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
157 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
158
159 let connection_id = ConnectionId {
160 owner_id: self.epoch.load(SeqCst),
161 id: self.next_connection_id.fetch_add(1, SeqCst),
162 };
163 let connection_state = ConnectionState {
164 outgoing_tx,
165 next_message_id: Default::default(),
166 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
167 };
168 let mut writer = MessageStream::new(connection.tx);
169 let mut reader = MessageStream::new(connection.rx);
170
171 let this = self.clone();
172 let response_channels = connection_state.response_channels.clone();
173 let handle_io = async move {
174 tracing::trace!(%connection_id, "handle io future: start");
175
176 let _end_connection = util::defer(|| {
177 response_channels.lock().take();
178 this.connections.write().remove(&connection_id);
179 tracing::trace!(%connection_id, "handle io future: end");
180 });
181
182 // Send messages on this frequency so the connection isn't closed.
183 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
184 futures::pin_mut!(keepalive_timer);
185
186 // Disconnect if we don't receive messages at least this frequently.
187 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
188 futures::pin_mut!(receive_timeout);
189
190 loop {
191 tracing::trace!(%connection_id, "outer loop iteration start");
192 let read_message = reader.read().fuse();
193 futures::pin_mut!(read_message);
194
195 loop {
196 tracing::trace!(%connection_id, "inner loop iteration start");
197 futures::select_biased! {
198 outgoing = outgoing_rx.next().fuse() => match outgoing {
199 Some(outgoing) => {
200 tracing::trace!(%connection_id, "outgoing rpc message: writing");
201 futures::select_biased! {
202 result = writer.write(outgoing).fuse() => {
203 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
204 result.context("failed to write RPC message")?;
205 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
206 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
207 }
208 _ = create_timer(WRITE_TIMEOUT).fuse() => {
209 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
210 Err(anyhow!("timed out writing message"))?;
211 }
212 }
213 }
214 None => {
215 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
216 return Ok(())
217 },
218 },
219 _ = keepalive_timer => {
220 tracing::trace!(%connection_id, "keepalive interval: pinging");
221 futures::select_biased! {
222 result = writer.write(proto::Message::Ping).fuse() => {
223 tracing::trace!(%connection_id, "keepalive interval: done pinging");
224 result.context("failed to send keepalive")?;
225 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
226 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
227 }
228 _ = create_timer(WRITE_TIMEOUT).fuse() => {
229 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
230 Err(anyhow!("timed out sending keepalive"))?;
231 }
232 }
233 }
234 incoming = read_message => {
235 let incoming = incoming.context("error reading rpc message from socket")?;
236 tracing::trace!(%connection_id, "incoming rpc message: received");
237 tracing::trace!(%connection_id, "receive timeout: resetting");
238 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
239 if let proto::Message::Envelope(incoming) = incoming {
240 tracing::trace!(%connection_id, "incoming rpc message: processing");
241 futures::select_biased! {
242 result = incoming_tx.send(incoming).fuse() => match result {
243 Ok(_) => {
244 tracing::trace!(%connection_id, "incoming rpc message: processed");
245 }
246 Err(_) => {
247 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
248 return Ok(())
249 }
250 },
251 _ = create_timer(WRITE_TIMEOUT).fuse() => {
252 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
253 Err(anyhow!("timed out processing incoming message"))?
254 }
255 }
256 }
257 break;
258 },
259 _ = receive_timeout => {
260 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
261 Err(anyhow!("delay between messages too long"))?
262 }
263 }
264 }
265 }
266 };
267
268 let response_channels = connection_state.response_channels.clone();
269 self.connections
270 .write()
271 .insert(connection_id, connection_state);
272
273 let incoming_rx = incoming_rx.filter_map(move |incoming| {
274 let response_channels = response_channels.clone();
275 async move {
276 let message_id = incoming.id;
277 tracing::trace!(?incoming, "incoming message future: start");
278 let _end = util::defer(move || {
279 tracing::trace!(%connection_id, message_id, "incoming message future: end");
280 });
281
282 if let Some(responding_to) = incoming.responding_to {
283 tracing::trace!(
284 %connection_id,
285 message_id,
286 responding_to,
287 "incoming response: received"
288 );
289 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
290 if let Some(tx) = channel {
291 let requester_resumed = oneshot::channel();
292 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
293 tracing::trace!(
294 %connection_id,
295 message_id,
296 responding_to = responding_to,
297 ?error,
298 "incoming response: request future dropped",
299 );
300 }
301
302 tracing::trace!(
303 %connection_id,
304 message_id,
305 responding_to,
306 "incoming response: waiting to resume requester"
307 );
308 let _ = requester_resumed.1.await;
309 tracing::trace!(
310 %connection_id,
311 message_id,
312 responding_to,
313 "incoming response: requester resumed"
314 );
315 } else {
316 tracing::warn!(
317 %connection_id,
318 message_id,
319 responding_to,
320 "incoming response: unknown request"
321 );
322 }
323
324 None
325 } else {
326 tracing::trace!(%connection_id, message_id, "incoming message: received");
327 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
328 tracing::error!(
329 %connection_id,
330 message_id,
331 "unable to construct a typed envelope"
332 );
333 None
334 })
335 }
336 }
337 });
338 (connection_id, handle_io, incoming_rx.boxed())
339 }
340
341 #[cfg(any(test, feature = "test-support"))]
342 pub fn add_test_connection(
343 self: &Arc<Self>,
344 connection: Connection,
345 executor: gpui::BackgroundExecutor,
346 ) -> (
347 ConnectionId,
348 impl Future<Output = anyhow::Result<()>> + Send,
349 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
350 ) {
351 let executor = executor.clone();
352 self.add_connection(connection, move |duration| executor.timer(duration))
353 }
354
355 pub fn disconnect(&self, connection_id: ConnectionId) {
356 self.connections.write().remove(&connection_id);
357 }
358
359 pub fn reset(&self, epoch: u32) {
360 self.teardown();
361 self.next_connection_id.store(0, SeqCst);
362 self.epoch.store(epoch, SeqCst);
363 }
364
365 pub fn teardown(&self) {
366 self.connections.write().clear();
367 }
368
369 pub fn request<T: RequestMessage>(
370 &self,
371 receiver_id: ConnectionId,
372 request: T,
373 ) -> impl Future<Output = Result<T::Response>> {
374 self.request_internal(None, receiver_id, request)
375 .map_ok(|envelope| envelope.payload)
376 }
377
378 pub fn request_envelope<T: RequestMessage>(
379 &self,
380 receiver_id: ConnectionId,
381 request: T,
382 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
383 self.request_internal(None, receiver_id, request)
384 }
385
386 pub fn forward_request<T: RequestMessage>(
387 &self,
388 sender_id: ConnectionId,
389 receiver_id: ConnectionId,
390 request: T,
391 ) -> impl Future<Output = Result<T::Response>> {
392 self.request_internal(Some(sender_id), receiver_id, request)
393 .map_ok(|envelope| envelope.payload)
394 }
395
396 pub fn request_internal<T: RequestMessage>(
397 &self,
398 original_sender_id: Option<ConnectionId>,
399 receiver_id: ConnectionId,
400 request: T,
401 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
402 let (tx, rx) = oneshot::channel();
403 let send = self.connection_state(receiver_id).and_then(|connection| {
404 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
405 connection
406 .response_channels
407 .lock()
408 .as_mut()
409 .ok_or_else(|| anyhow!("connection was closed"))?
410 .insert(message_id, tx);
411 connection
412 .outgoing_tx
413 .unbounded_send(proto::Message::Envelope(request.into_envelope(
414 message_id,
415 None,
416 original_sender_id.map(Into::into),
417 )))
418 .map_err(|_| anyhow!("connection was closed"))?;
419 Ok(())
420 });
421 async move {
422 send?;
423 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
424
425 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
426 Err(anyhow!(
427 "RPC request {} failed - {}",
428 T::NAME,
429 error.message
430 ))
431 } else {
432 Ok(TypedEnvelope {
433 message_id: response.id,
434 sender_id: receiver_id,
435 original_sender_id: response.original_sender_id,
436 payload: T::Response::from_envelope(response)
437 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
438 })
439 }
440 }
441 }
442
443 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
444 let connection = self.connection_state(receiver_id)?;
445 let message_id = connection
446 .next_message_id
447 .fetch_add(1, atomic::Ordering::SeqCst);
448 connection
449 .outgoing_tx
450 .unbounded_send(proto::Message::Envelope(
451 message.into_envelope(message_id, None, None),
452 ))?;
453 Ok(())
454 }
455
456 pub fn forward_send<T: EnvelopedMessage>(
457 &self,
458 sender_id: ConnectionId,
459 receiver_id: ConnectionId,
460 message: T,
461 ) -> Result<()> {
462 let connection = self.connection_state(receiver_id)?;
463 let message_id = connection
464 .next_message_id
465 .fetch_add(1, atomic::Ordering::SeqCst);
466 connection
467 .outgoing_tx
468 .unbounded_send(proto::Message::Envelope(message.into_envelope(
469 message_id,
470 None,
471 Some(sender_id.into()),
472 )))?;
473 Ok(())
474 }
475
476 pub fn respond<T: RequestMessage>(
477 &self,
478 receipt: Receipt<T>,
479 response: T::Response,
480 ) -> Result<()> {
481 let connection = self.connection_state(receipt.sender_id)?;
482 let message_id = connection
483 .next_message_id
484 .fetch_add(1, atomic::Ordering::SeqCst);
485 connection
486 .outgoing_tx
487 .unbounded_send(proto::Message::Envelope(response.into_envelope(
488 message_id,
489 Some(receipt.message_id),
490 None,
491 )))?;
492 Ok(())
493 }
494
495 pub fn respond_with_error<T: RequestMessage>(
496 &self,
497 receipt: Receipt<T>,
498 response: proto::Error,
499 ) -> Result<()> {
500 let connection = self.connection_state(receipt.sender_id)?;
501 let message_id = connection
502 .next_message_id
503 .fetch_add(1, atomic::Ordering::SeqCst);
504 connection
505 .outgoing_tx
506 .unbounded_send(proto::Message::Envelope(response.into_envelope(
507 message_id,
508 Some(receipt.message_id),
509 None,
510 )))?;
511 Ok(())
512 }
513
514 pub fn respond_with_unhandled_message(
515 &self,
516 envelope: Box<dyn AnyTypedEnvelope>,
517 ) -> Result<()> {
518 let connection = self.connection_state(envelope.sender_id())?;
519 let response = proto::Error {
520 message: format!("message {} was not handled", envelope.payload_type_name()),
521 };
522 let message_id = connection
523 .next_message_id
524 .fetch_add(1, atomic::Ordering::SeqCst);
525 connection
526 .outgoing_tx
527 .unbounded_send(proto::Message::Envelope(response.into_envelope(
528 message_id,
529 Some(envelope.message_id()),
530 None,
531 )))?;
532 Ok(())
533 }
534
535 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
536 let connections = self.connections.read();
537 let connection = connections
538 .get(&connection_id)
539 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
540 Ok(connection.clone())
541 }
542}
543
544impl Serialize for Peer {
545 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
546 where
547 S: serde::Serializer,
548 {
549 let mut state = serializer.serialize_struct("Peer", 2)?;
550 state.serialize_field("connections", &*self.connections.read())?;
551 state.end()
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Turns on logging when `RUST_LOG` is set.
    ///
    /// Every iterated test calls this repeatedly in the same process, so
    /// `try_init` is used: `env_logger::init` panics once a logger is installed.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            let _ = env_logger::try_init();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client disconnects its own connection.
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}