1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt, TryFutureExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::{fmt, sync::atomic::Ordering::SeqCst};
15use std::{
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24use tracing::instrument;
25
/// Identifies one connection held by a [`Peer`].
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare `owner_id` first, then `id`.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// Set from the owning peer's epoch at the time the connection is added.
    pub owner_id: u32,
    /// Taken from the owning peer's connection counter at the time the connection is added.
    pub id: u32,
}
31
32impl Into<PeerId> for ConnectionId {
33 fn into(self) -> PeerId {
34 PeerId {
35 owner_id: self.owner_id,
36 id: self.id,
37 }
38 }
39}
40
41impl From<PeerId> for ConnectionId {
42 fn from(peer_id: PeerId) -> Self {
43 Self {
44 owner_id: peer_id.owner_id,
45 id: peer_id.id,
46 }
47 }
48}
49
50impl fmt::Display for ConnectionId {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{}/{}", self.owner_id, self.id)
53 }
54}
55
/// A typed handle identifying a received request, used to route a response back to its sender.
///
/// Created by `TypedEnvelope::receipt` and consumed by `Peer::respond` /
/// `Peer::respond_with_error`.
pub struct Receipt<T> {
    /// Connection the request was received on; responses are sent back over it.
    pub sender_id: ConnectionId,
    /// Id of the request message; used as the `responding_to` value of the response.
    pub message_id: u32,
    /// Ties the receipt to the request type `T` without storing a value of that type.
    payload_type: PhantomData<T>,
}
61
62impl<T> Clone for Receipt<T> {
63 fn clone(&self) -> Self {
64 Self {
65 sender_id: self.sender_id,
66 message_id: self.message_id,
67 payload_type: PhantomData,
68 }
69 }
70}
71
// Every field of `Receipt<T>` is `Copy` (ids plus `PhantomData`), so the receipt is `Copy`
// for any `T`; the impl is written by hand to avoid a derived `T: Copy` bound.
impl<T> Copy for Receipt<T> {}
73
/// A message paired with the routing metadata it carried.
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
    /// Connection associated with the message. For responses produced by
    /// `Peer::request_internal` this is the connection the request was sent to.
    pub sender_id: ConnectionId,
    /// Original sender carried by the envelope, if any. `Peer::forward_send` and
    /// `Peer::forward_request` set this on the messages they emit.
    pub original_sender_id: Option<PeerId>,
    /// Id of the message within its connection.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
81
82impl<T> TypedEnvelope<T> {
83 pub fn original_sender_id(&self) -> Result<PeerId> {
84 self.original_sender_id
85 .ok_or_else(|| anyhow!("missing original_sender_id"))
86 }
87}
88
89impl<T: RequestMessage> TypedEnvelope<T> {
90 pub fn receipt(&self) -> Receipt<T> {
91 Receipt {
92 sender_id: self.sender_id,
93 message_id: self.message_id,
94 payload_type: PhantomData,
95 }
96 }
97}
98
/// Tracks a set of connections and routes messages, requests and responses over them.
pub struct Peer {
    /// Current epoch; used as the `owner_id` of every `ConnectionId` this peer hands out.
    epoch: AtomicU32,
    /// State of each registered connection, keyed by its id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Counter supplying the `id` part of the next `ConnectionId`.
    next_connection_id: AtomicU32,
}
104
/// Per-connection state shared between the `Peer` and the connection's I/O future.
///
/// Cloning is cheap-ish: the sender is cloned and the `Arc`s are shared, so all clones observe
/// the same message-id counter and response-channel table.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Sender half of the unbounded outgoing queue; the I/O future created in
    /// `Peer::add_connection` drains the receiving half and writes to the socket.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Counter supplying the id of the next message sent on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Pending requests, keyed by request message id. Each entry is completed with the response
    /// envelope plus a oneshot sender that the requester holds while it processes the response.
    /// The whole table is replaced by `None` by the I/O future's cleanup closure.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
115
/// How long the I/O future waits after the last successful write before sending a ping.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Upper bound on a single socket write, and on handing one incoming message to the
/// incoming-message channel.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum allowed gap between received messages before the I/O future fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
119
120impl Peer {
121 pub fn new(epoch: u32) -> Arc<Self> {
122 Arc::new(Self {
123 epoch: AtomicU32::new(epoch),
124 connections: Default::default(),
125 next_connection_id: Default::default(),
126 })
127 }
128
129 pub fn epoch(&self) -> u32 {
130 self.epoch.load(SeqCst)
131 }
132
    /// Registers `connection` with this peer.
    ///
    /// Returns a tuple of:
    /// 1. the id assigned to the connection (`owner_id` is the current epoch, `id` comes from the
    ///    peer's connection counter);
    /// 2. a future that performs all socket I/O for the connection — the caller must drive it;
    /// 3. a stream of incoming non-response messages as typed envelopes.
    ///
    /// `create_timer` is called with a duration whenever a timer is needed (keepalive interval,
    /// receive timeout, per-write timeout); the future it returns is expected to complete after
    /// that duration.
    ///
    /// The I/O future resolves to `Ok(())` when the outgoing queue or the incoming channel is
    /// closed, and to an error on read/write failure or when any of the timeouts fires.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Cleanup for the connection: clears the pending-response table (dropping every
            // pending response sender) and unregisters the connection from the peer.
            // NOTE(review): presumably `util::defer` runs the closure when the returned guard is
            // dropped, i.e. whenever this future finishes or is dropped — confirm in `util`.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // Outer loop: one iteration per message read from the socket. The read future is
            // created once here and polled across many inner iterations, so an in-progress read
            // is not discarded when an outgoing write or keepalive is handled first.
            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                // Inner loop: `select_biased!` polls its branches in the order written, so
                // outgoing writes take priority over keepalives, then reads, then the receive
                // timeout.
                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                // Race the write against a fresh write timeout.
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        // Any successful write counts as keepalive traffic.
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            // All senders of the outgoing queue were dropped: finish cleanly.
                            None => {
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            // Every received message (not only envelopes) resets the timeout.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            // Only envelopes are forwarded to the incoming stream; other message
                            // kinds are dropped here after resetting the timeout.
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                // The incoming channel is bounded, so this send can block; bound
                                // the wait with the write timeout.
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        // The incoming stream was dropped: finish cleanly.
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // The read future has completed; leave the inner loop so the outer
                            // loop creates a fresh one.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Split incoming envelopes: responses are delivered to the matching pending request and
        // filtered out of the stream; everything else is yielded as a typed envelope.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `as_mut()?` yields `None` (dropping the message) once the table has been
                    // taken by the I/O future's cleanup.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Hold back later messages on this stream until the sender half handed to
                        // the requester is dropped (or used); the result itself is ignored.
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
340
341 #[cfg(any(test, feature = "test-support"))]
342 pub fn add_test_connection(
343 self: &Arc<Self>,
344 connection: Connection,
345 executor: Arc<gpui::executor::Background>,
346 ) -> (
347 ConnectionId,
348 impl Future<Output = anyhow::Result<()>> + Send,
349 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
350 ) {
351 let executor = executor.clone();
352 self.add_connection(connection, move |duration| executor.timer(duration))
353 }
354
355 pub fn disconnect(&self, connection_id: ConnectionId) {
356 self.connections.write().remove(&connection_id);
357 }
358
359 pub fn reset(&self, epoch: u32) {
360 self.teardown();
361 self.next_connection_id.store(0, SeqCst);
362 self.epoch.store(epoch, SeqCst);
363 }
364
365 pub fn teardown(&self) {
366 self.connections.write().clear();
367 }
368
369 pub fn request<T: RequestMessage>(
370 &self,
371 receiver_id: ConnectionId,
372 request: T,
373 ) -> impl Future<Output = Result<T::Response>> {
374 self.request_internal(None, receiver_id, request)
375 .map_ok(|envelope| envelope.payload)
376 }
377
378 pub fn request_envelope<T: RequestMessage>(
379 &self,
380 receiver_id: ConnectionId,
381 request: T,
382 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
383 self.request_internal(None, receiver_id, request)
384 }
385
386 pub fn forward_request<T: RequestMessage>(
387 &self,
388 sender_id: ConnectionId,
389 receiver_id: ConnectionId,
390 request: T,
391 ) -> impl Future<Output = Result<T::Response>> {
392 self.request_internal(Some(sender_id), receiver_id, request)
393 .map_ok(|envelope| envelope.payload)
394 }
395
396 pub fn request_internal<T: RequestMessage>(
397 &self,
398 original_sender_id: Option<ConnectionId>,
399 receiver_id: ConnectionId,
400 request: T,
401 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
402 let (tx, rx) = oneshot::channel();
403 let send = self.connection_state(receiver_id).and_then(|connection| {
404 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
405 connection
406 .response_channels
407 .lock()
408 .as_mut()
409 .ok_or_else(|| anyhow!("connection was closed"))?
410 .insert(message_id, tx);
411 connection
412 .outgoing_tx
413 .unbounded_send(proto::Message::Envelope(request.into_envelope(
414 message_id,
415 None,
416 original_sender_id.map(Into::into),
417 )))
418 .map_err(|_| anyhow!("connection was closed"))?;
419 Ok(())
420 });
421 async move {
422 send?;
423 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
424
425 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
426 Err(anyhow!(
427 "RPC request {} failed - {}",
428 T::NAME,
429 error.message
430 ))
431 } else {
432 Ok(TypedEnvelope {
433 message_id: response.id,
434 sender_id: receiver_id,
435 original_sender_id: response.original_sender_id,
436 payload: T::Response::from_envelope(response)
437 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
438 })
439 }
440 }
441 }
442
443 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
444 let connection = self.connection_state(receiver_id)?;
445 let message_id = connection
446 .next_message_id
447 .fetch_add(1, atomic::Ordering::SeqCst);
448 connection
449 .outgoing_tx
450 .unbounded_send(proto::Message::Envelope(
451 message.into_envelope(message_id, None, None),
452 ))?;
453 Ok(())
454 }
455
456 pub fn forward_send<T: EnvelopedMessage>(
457 &self,
458 sender_id: ConnectionId,
459 receiver_id: ConnectionId,
460 message: T,
461 ) -> Result<()> {
462 let connection = self.connection_state(receiver_id)?;
463 let message_id = connection
464 .next_message_id
465 .fetch_add(1, atomic::Ordering::SeqCst);
466 connection
467 .outgoing_tx
468 .unbounded_send(proto::Message::Envelope(message.into_envelope(
469 message_id,
470 None,
471 Some(sender_id.into()),
472 )))?;
473 Ok(())
474 }
475
476 pub fn respond<T: RequestMessage>(
477 &self,
478 receipt: Receipt<T>,
479 response: T::Response,
480 ) -> Result<()> {
481 let connection = self.connection_state(receipt.sender_id)?;
482 let message_id = connection
483 .next_message_id
484 .fetch_add(1, atomic::Ordering::SeqCst);
485 connection
486 .outgoing_tx
487 .unbounded_send(proto::Message::Envelope(response.into_envelope(
488 message_id,
489 Some(receipt.message_id),
490 None,
491 )))?;
492 Ok(())
493 }
494
495 pub fn respond_with_error<T: RequestMessage>(
496 &self,
497 receipt: Receipt<T>,
498 response: proto::Error,
499 ) -> Result<()> {
500 let connection = self.connection_state(receipt.sender_id)?;
501 let message_id = connection
502 .next_message_id
503 .fetch_add(1, atomic::Ordering::SeqCst);
504 connection
505 .outgoing_tx
506 .unbounded_send(proto::Message::Envelope(response.into_envelope(
507 message_id,
508 Some(receipt.message_id),
509 None,
510 )))?;
511 Ok(())
512 }
513
514 pub fn respond_with_unhandled_message(
515 &self,
516 envelope: Box<dyn AnyTypedEnvelope>,
517 ) -> Result<()> {
518 let connection = self.connection_state(envelope.sender_id())?;
519 let response = proto::Error {
520 message: format!("message {} was not handled", envelope.payload_type_name()),
521 };
522 let message_id = connection
523 .next_message_id
524 .fetch_add(1, atomic::Ordering::SeqCst);
525 connection
526 .outgoing_tx
527 .unbounded_send(proto::Message::Envelope(response.into_envelope(
528 message_id,
529 Some(envelope.message_id()),
530 None,
531 )))?;
532 Ok(())
533 }
534
535 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
536 let connections = self.connections.read();
537 let connection = connections
538 .get(&connection_id)
539 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
540 Ok(connection.clone())
541 }
542}
543
544impl Serialize for Peer {
545 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
546 where
547 S: serde::Serializer,
548 {
549 let mut state = serializer.serialize_struct("Peer", 2)?;
550 state.serialize_field("connections", &*self.connections.read())?;
551 state.end()
552 }
553}
554
555#[cfg(test)]
556mod tests {
557 use super::*;
558 use crate::TypedEnvelope;
559 use async_tungstenite::tungstenite::Message as WebSocketMessage;
560 use gpui::TestAppContext;
561
562 #[ctor::ctor]
563 fn init_logger() {
564 if std::env::var("RUST_LOG").is_ok() {
565 env_logger::init();
566 }
567 }
568
569 #[gpui::test(iterations = 50)]
570 async fn test_request_response(cx: &mut TestAppContext) {
571 let executor = cx.foreground();
572
573 // create 2 clients connected to 1 server
574 let server = Peer::new(0);
575 let client1 = Peer::new(0);
576 let client2 = Peer::new(0);
577
578 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
579 Connection::in_memory(cx.background());
580 let (client1_conn_id, io_task1, client1_incoming) =
581 client1.add_test_connection(client1_to_server_conn, cx.background());
582 let (_, io_task2, server_incoming1) =
583 server.add_test_connection(server_to_client_1_conn, cx.background());
584
585 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
586 Connection::in_memory(cx.background());
587 let (client2_conn_id, io_task3, client2_incoming) =
588 client2.add_test_connection(client2_to_server_conn, cx.background());
589 let (_, io_task4, server_incoming2) =
590 server.add_test_connection(server_to_client_2_conn, cx.background());
591
592 executor.spawn(io_task1).detach();
593 executor.spawn(io_task2).detach();
594 executor.spawn(io_task3).detach();
595 executor.spawn(io_task4).detach();
596 executor
597 .spawn(handle_messages(server_incoming1, server.clone()))
598 .detach();
599 executor
600 .spawn(handle_messages(client1_incoming, client1.clone()))
601 .detach();
602 executor
603 .spawn(handle_messages(server_incoming2, server.clone()))
604 .detach();
605 executor
606 .spawn(handle_messages(client2_incoming, client2.clone()))
607 .detach();
608
609 assert_eq!(
610 client1
611 .request(client1_conn_id, proto::Ping {},)
612 .await
613 .unwrap(),
614 proto::Ack {}
615 );
616
617 assert_eq!(
618 client2
619 .request(client2_conn_id, proto::Ping {},)
620 .await
621 .unwrap(),
622 proto::Ack {}
623 );
624
625 assert_eq!(
626 client1
627 .request(client1_conn_id, proto::Test { id: 1 },)
628 .await
629 .unwrap(),
630 proto::Test { id: 1 }
631 );
632
633 assert_eq!(
634 client2
635 .request(client2_conn_id, proto::Test { id: 2 })
636 .await
637 .unwrap(),
638 proto::Test { id: 2 }
639 );
640
641 client1.disconnect(client1_conn_id);
642 client2.disconnect(client1_conn_id);
643
644 async fn handle_messages(
645 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
646 peer: Arc<Peer>,
647 ) -> Result<()> {
648 while let Some(envelope) = messages.next().await {
649 let envelope = envelope.into_any();
650 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
651 let receipt = envelope.receipt();
652 peer.respond(receipt, proto::Ack {})?
653 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
654 {
655 peer.respond(envelope.receipt(), envelope.payload.clone())?
656 } else {
657 panic!("unknown message type");
658 }
659 }
660
661 Ok(())
662 }
663 }
664
    /// Checks that messages the server sends before a response are observed by the client
    /// before the request future resolves: the recorded events must be
    /// "message 1", "message 2", "response", in that order.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for the ping, send two plain messages, then the response.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes things.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: record the two plain messages as they arrive on the incoming stream.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
762
    /// Checks that dropping one in-flight request does not stall the connection: after the
    /// first request's task is dropped, the second request must still complete, and only after
    /// both plain messages have been observed.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: wait for both pings, send two plain messages, then both responses.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Shared log of the order in which the client observes things.
        let events = Arc::new(Mutex::new(Vec::new()));

        // The first request is spawned as a task so it can be dropped mid-flight below.
        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: record the two plain messages as they arrive on the incoming stream.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
874
875 #[gpui::test(iterations = 50)]
876 async fn test_disconnect(cx: &mut TestAppContext) {
877 let executor = cx.foreground();
878
879 let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
880
881 let client = Peer::new(0);
882 let (connection_id, io_handler, mut incoming) =
883 client.add_test_connection(client_conn, cx.background());
884
885 let (io_ended_tx, io_ended_rx) = oneshot::channel();
886 executor
887 .spawn(async move {
888 io_handler.await.ok();
889 io_ended_tx.send(()).unwrap();
890 })
891 .detach();
892
893 let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
894 executor
895 .spawn(async move {
896 incoming.next().await;
897 messages_ended_tx.send(()).unwrap();
898 })
899 .detach();
900
901 client.disconnect(connection_id);
902
903 let _ = io_ended_rx.await;
904 let _ = messages_ended_rx.await;
905 assert!(server_conn
906 .send(WebSocketMessage::Binary(vec![]))
907 .await
908 .is_err());
909 }
910
911 #[gpui::test(iterations = 50)]
912 async fn test_io_error(cx: &mut TestAppContext) {
913 let executor = cx.foreground();
914 let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
915
916 let client = Peer::new(0);
917 let (connection_id, io_handler, mut incoming) =
918 client.add_test_connection(client_conn, cx.background());
919 executor.spawn(io_handler).detach();
920 executor
921 .spawn(async move { incoming.next().await })
922 .detach();
923
924 let response = executor.spawn(client.request(connection_id, proto::Ping {}));
925 let _request = server_conn.rx.next().await.unwrap().unwrap();
926
927 drop(server_conn);
928 assert_eq!(
929 response.await.unwrap_err().to_string(),
930 "connection was closed"
931 );
932 }
933}