1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::{fmt, sync::atomic::Ordering::SeqCst};
15use std::{
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24use tracing::instrument;
25
26#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
27pub struct ConnectionId {
28 pub owner_id: u32,
29 pub id: u32,
30}
31
32impl Into<PeerId> for ConnectionId {
33 fn into(self) -> PeerId {
34 PeerId {
35 owner_id: self.owner_id,
36 id: self.id,
37 }
38 }
39}
40
41impl From<PeerId> for ConnectionId {
42 fn from(peer_id: PeerId) -> Self {
43 Self {
44 owner_id: peer_id.owner_id,
45 id: peer_id.id,
46 }
47 }
48}
49
50impl fmt::Display for ConnectionId {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{}/{}", self.owner_id, self.id)
53 }
54}
55
56pub struct Receipt<T> {
57 pub sender_id: ConnectionId,
58 pub message_id: u32,
59 payload_type: PhantomData<T>,
60}
61
62impl<T> Clone for Receipt<T> {
63 fn clone(&self) -> Self {
64 Self {
65 sender_id: self.sender_id,
66 message_id: self.message_id,
67 payload_type: PhantomData,
68 }
69 }
70}
71
72impl<T> Copy for Receipt<T> {}
73
74pub struct TypedEnvelope<T> {
75 pub sender_id: ConnectionId,
76 pub original_sender_id: Option<PeerId>,
77 pub message_id: u32,
78 pub payload: T,
79}
80
81impl<T> TypedEnvelope<T> {
82 pub fn original_sender_id(&self) -> Result<PeerId> {
83 self.original_sender_id
84 .ok_or_else(|| anyhow!("missing original_sender_id"))
85 }
86}
87
88impl<T: RequestMessage> TypedEnvelope<T> {
89 pub fn receipt(&self) -> Receipt<T> {
90 Receipt {
91 sender_id: self.sender_id,
92 message_id: self.message_id,
93 payload_type: PhantomData,
94 }
95 }
96}
97
98pub struct Peer {
99 epoch: AtomicU32,
100 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
101 next_connection_id: AtomicU32,
102}
103
104#[derive(Clone, Serialize)]
105pub struct ConnectionState {
106 #[serde(skip)]
107 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
108 next_message_id: Arc<AtomicU32>,
109 #[allow(clippy::type_complexity)]
110 #[serde(skip)]
111 response_channels:
112 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
113}
114
/// How long the writer may stay idle before it sends a `Ping` to keep the connection alive.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on a single socket write (or hand-off of an incoming message) before failing.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Maximum gap between received messages before the connection is considered dead.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
118
119impl Peer {
120 pub fn new(epoch: u32) -> Arc<Self> {
121 Arc::new(Self {
122 epoch: AtomicU32::new(epoch),
123 connections: Default::default(),
124 next_connection_id: Default::default(),
125 })
126 }
127
128 pub fn epoch(&self) -> u32 {
129 self.epoch.load(SeqCst)
130 }
131
132 #[instrument(skip_all)]
133 pub fn add_connection<F, Fut, Out>(
134 self: &Arc<Self>,
135 connection: Connection,
136 create_timer: F,
137 ) -> (
138 ConnectionId,
139 impl Future<Output = anyhow::Result<()>> + Send,
140 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
141 )
142 where
143 F: Send + Fn(Duration) -> Fut,
144 Fut: Send + Future<Output = Out>,
145 Out: Send,
146 {
147 // For outgoing messages, use an unbounded channel so that application code
148 // can always send messages without yielding. For incoming messages, use a
149 // bounded channel so that other peers will receive backpressure if they send
150 // messages faster than this peer can process them.
151 #[cfg(any(test, feature = "test-support"))]
152 const INCOMING_BUFFER_SIZE: usize = 1;
153 #[cfg(not(any(test, feature = "test-support")))]
154 const INCOMING_BUFFER_SIZE: usize = 64;
155 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
156 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
157
158 let connection_id = ConnectionId {
159 owner_id: self.epoch.load(SeqCst),
160 id: self.next_connection_id.fetch_add(1, SeqCst),
161 };
162 let connection_state = ConnectionState {
163 outgoing_tx,
164 next_message_id: Default::default(),
165 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
166 };
167 let mut writer = MessageStream::new(connection.tx);
168 let mut reader = MessageStream::new(connection.rx);
169
170 let this = self.clone();
171 let response_channels = connection_state.response_channels.clone();
172 let handle_io = async move {
173 tracing::debug!(%connection_id, "handle io future: start");
174
175 let _end_connection = util::defer(|| {
176 response_channels.lock().take();
177 this.connections.write().remove(&connection_id);
178 tracing::debug!(%connection_id, "handle io future: end");
179 });
180
181 // Send messages on this frequency so the connection isn't closed.
182 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
183 futures::pin_mut!(keepalive_timer);
184
185 // Disconnect if we don't receive messages at least this frequently.
186 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
187 futures::pin_mut!(receive_timeout);
188
189 loop {
190 tracing::debug!(%connection_id, "outer loop iteration start");
191 let read_message = reader.read().fuse();
192 futures::pin_mut!(read_message);
193
194 loop {
195 tracing::debug!(%connection_id, "inner loop iteration start");
196 futures::select_biased! {
197 outgoing = outgoing_rx.next().fuse() => match outgoing {
198 Some(outgoing) => {
199 tracing::debug!(%connection_id, "outgoing rpc message: writing");
200 futures::select_biased! {
201 result = writer.write(outgoing).fuse() => {
202 tracing::debug!(%connection_id, "outgoing rpc message: done writing");
203 result.context("failed to write RPC message")?;
204 tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
205 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
206 }
207 _ = create_timer(WRITE_TIMEOUT).fuse() => {
208 tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
209 Err(anyhow!("timed out writing message"))?;
210 }
211 }
212 }
213 None => {
214 tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
215 return Ok(())
216 },
217 },
218 _ = keepalive_timer => {
219 tracing::debug!(%connection_id, "keepalive interval: pinging");
220 futures::select_biased! {
221 result = writer.write(proto::Message::Ping).fuse() => {
222 tracing::debug!(%connection_id, "keepalive interval: done pinging");
223 result.context("failed to send keepalive")?;
224 tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
225 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
226 }
227 _ = create_timer(WRITE_TIMEOUT).fuse() => {
228 tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
229 Err(anyhow!("timed out sending keepalive"))?;
230 }
231 }
232 }
233 incoming = read_message => {
234 let incoming = incoming.context("error reading rpc message from socket")?;
235 tracing::debug!(%connection_id, "incoming rpc message: received");
236 tracing::debug!(%connection_id, "receive timeout: resetting");
237 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
238 if let proto::Message::Envelope(incoming) = incoming {
239 tracing::debug!(%connection_id, "incoming rpc message: processing");
240 futures::select_biased! {
241 result = incoming_tx.send(incoming).fuse() => match result {
242 Ok(_) => {
243 tracing::debug!(%connection_id, "incoming rpc message: processed");
244 }
245 Err(_) => {
246 tracing::debug!(%connection_id, "incoming rpc message: channel closed");
247 return Ok(())
248 }
249 },
250 _ = create_timer(WRITE_TIMEOUT).fuse() => {
251 tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
252 Err(anyhow!("timed out processing incoming message"))?
253 }
254 }
255 }
256 break;
257 },
258 _ = receive_timeout => {
259 tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
260 Err(anyhow!("delay between messages too long"))?
261 }
262 }
263 }
264 }
265 };
266
267 let response_channels = connection_state.response_channels.clone();
268 self.connections
269 .write()
270 .insert(connection_id, connection_state);
271
272 let incoming_rx = incoming_rx.filter_map(move |incoming| {
273 let response_channels = response_channels.clone();
274 async move {
275 let message_id = incoming.id;
276 tracing::debug!(?incoming, "incoming message future: start");
277 let _end = util::defer(move || {
278 tracing::debug!(%connection_id, message_id, "incoming message future: end");
279 });
280
281 if let Some(responding_to) = incoming.responding_to {
282 tracing::debug!(
283 %connection_id,
284 message_id,
285 responding_to,
286 "incoming response: received"
287 );
288 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
289 if let Some(tx) = channel {
290 let requester_resumed = oneshot::channel();
291 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
292 tracing::debug!(
293 %connection_id,
294 message_id,
295 responding_to = responding_to,
296 ?error,
297 "incoming response: request future dropped",
298 );
299 }
300
301 tracing::debug!(
302 %connection_id,
303 message_id,
304 responding_to,
305 "incoming response: waiting to resume requester"
306 );
307 let _ = requester_resumed.1.await;
308 tracing::debug!(
309 %connection_id,
310 message_id,
311 responding_to,
312 "incoming response: requester resumed"
313 );
314 } else {
315 tracing::warn!(
316 %connection_id,
317 message_id,
318 responding_to,
319 "incoming response: unknown request"
320 );
321 }
322
323 None
324 } else {
325 tracing::debug!(%connection_id, message_id, "incoming message: received");
326 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
327 tracing::error!(
328 %connection_id,
329 message_id,
330 "unable to construct a typed envelope"
331 );
332 None
333 })
334 }
335 }
336 });
337 (connection_id, handle_io, incoming_rx.boxed())
338 }
339
340 #[cfg(any(test, feature = "test-support"))]
341 pub fn add_test_connection(
342 self: &Arc<Self>,
343 connection: Connection,
344 executor: Arc<gpui::executor::Background>,
345 ) -> (
346 ConnectionId,
347 impl Future<Output = anyhow::Result<()>> + Send,
348 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
349 ) {
350 let executor = executor.clone();
351 self.add_connection(connection, move |duration| executor.timer(duration))
352 }
353
354 pub fn disconnect(&self, connection_id: ConnectionId) {
355 self.connections.write().remove(&connection_id);
356 }
357
358 pub fn reset(&self, epoch: u32) {
359 self.teardown();
360 self.next_connection_id.store(0, SeqCst);
361 self.epoch.store(epoch, SeqCst);
362 }
363
364 pub fn teardown(&self) {
365 self.connections.write().clear();
366 }
367
368 pub fn request<T: RequestMessage>(
369 &self,
370 receiver_id: ConnectionId,
371 request: T,
372 ) -> impl Future<Output = Result<T::Response>> {
373 self.request_internal(None, receiver_id, request)
374 }
375
376 pub fn forward_request<T: RequestMessage>(
377 &self,
378 sender_id: ConnectionId,
379 receiver_id: ConnectionId,
380 request: T,
381 ) -> impl Future<Output = Result<T::Response>> {
382 self.request_internal(Some(sender_id), receiver_id, request)
383 }
384
385 pub fn request_internal<T: RequestMessage>(
386 &self,
387 original_sender_id: Option<ConnectionId>,
388 receiver_id: ConnectionId,
389 request: T,
390 ) -> impl Future<Output = Result<T::Response>> {
391 let (tx, rx) = oneshot::channel();
392 let send = self.connection_state(receiver_id).and_then(|connection| {
393 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
394 connection
395 .response_channels
396 .lock()
397 .as_mut()
398 .ok_or_else(|| anyhow!("connection was closed"))?
399 .insert(message_id, tx);
400 connection
401 .outgoing_tx
402 .unbounded_send(proto::Message::Envelope(request.into_envelope(
403 message_id,
404 None,
405 original_sender_id.map(Into::into),
406 )))
407 .map_err(|_| anyhow!("connection was closed"))?;
408 Ok(())
409 });
410 async move {
411 send?;
412 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
413 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
414 Err(anyhow!(
415 "RPC request {} failed - {}",
416 T::NAME,
417 error.message
418 ))
419 } else {
420 T::Response::from_envelope(response)
421 .ok_or_else(|| anyhow!("received response of the wrong type"))
422 }
423 }
424 }
425
426 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
427 let connection = self.connection_state(receiver_id)?;
428 let message_id = connection
429 .next_message_id
430 .fetch_add(1, atomic::Ordering::SeqCst);
431 connection
432 .outgoing_tx
433 .unbounded_send(proto::Message::Envelope(
434 message.into_envelope(message_id, None, None),
435 ))?;
436 Ok(())
437 }
438
439 pub fn forward_send<T: EnvelopedMessage>(
440 &self,
441 sender_id: ConnectionId,
442 receiver_id: ConnectionId,
443 message: T,
444 ) -> Result<()> {
445 let connection = self.connection_state(receiver_id)?;
446 let message_id = connection
447 .next_message_id
448 .fetch_add(1, atomic::Ordering::SeqCst);
449 connection
450 .outgoing_tx
451 .unbounded_send(proto::Message::Envelope(message.into_envelope(
452 message_id,
453 None,
454 Some(sender_id.into()),
455 )))?;
456 Ok(())
457 }
458
459 pub fn respond<T: RequestMessage>(
460 &self,
461 receipt: Receipt<T>,
462 response: T::Response,
463 ) -> Result<()> {
464 let connection = self.connection_state(receipt.sender_id)?;
465 let message_id = connection
466 .next_message_id
467 .fetch_add(1, atomic::Ordering::SeqCst);
468 connection
469 .outgoing_tx
470 .unbounded_send(proto::Message::Envelope(response.into_envelope(
471 message_id,
472 Some(receipt.message_id),
473 None,
474 )))?;
475 Ok(())
476 }
477
478 pub fn respond_with_error<T: RequestMessage>(
479 &self,
480 receipt: Receipt<T>,
481 response: proto::Error,
482 ) -> Result<()> {
483 let connection = self.connection_state(receipt.sender_id)?;
484 let message_id = connection
485 .next_message_id
486 .fetch_add(1, atomic::Ordering::SeqCst);
487 connection
488 .outgoing_tx
489 .unbounded_send(proto::Message::Envelope(response.into_envelope(
490 message_id,
491 Some(receipt.message_id),
492 None,
493 )))?;
494 Ok(())
495 }
496
497 pub fn respond_with_unhandled_message(
498 &self,
499 envelope: Box<dyn AnyTypedEnvelope>,
500 ) -> Result<()> {
501 let connection = self.connection_state(envelope.sender_id())?;
502 let response = proto::Error {
503 message: format!("message {} was not handled", envelope.payload_type_name()),
504 };
505 let message_id = connection
506 .next_message_id
507 .fetch_add(1, atomic::Ordering::SeqCst);
508 connection
509 .outgoing_tx
510 .unbounded_send(proto::Message::Envelope(response.into_envelope(
511 message_id,
512 Some(envelope.message_id()),
513 None,
514 )))?;
515 Ok(())
516 }
517
518 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
519 let connections = self.connections.read();
520 let connection = connections
521 .get(&connection_id)
522 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
523 Ok(connection.clone())
524 }
525}
526
527impl Serialize for Peer {
528 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
529 where
530 S: serde::Serializer,
531 {
532 let mut state = serializer.serialize_struct("Peer", 2)?;
533 state.serialize_field("connections", &*self.connections.read())?;
534 state.end()
535 }
536}
537
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Turns on `env_logger` output for these tests when `RUST_LOG` is set.
    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    /// Two clients talk to one server, and each request gets back its matching response.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.background());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.background());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.background());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client disconnects its own connection id.
        client2.disconnect(client2_conn_id);

        /// Replies `Ack` to `Ping` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response must be observed before that response resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not stall the incoming stream or later requests.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` ends both the I/O future and the incoming stream, and closes the socket.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote side goes away.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}