1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::{fmt, sync::atomic::Ordering::SeqCst};
15use std::{
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24use tracing::instrument;
25
26#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
27pub struct ConnectionId {
28 pub owner_id: u32,
29 pub id: u32,
30}
31
32impl Into<PeerId> for ConnectionId {
33 fn into(self) -> PeerId {
34 PeerId {
35 owner_id: self.owner_id,
36 id: self.id,
37 }
38 }
39}
40
41impl From<PeerId> for ConnectionId {
42 fn from(peer_id: PeerId) -> Self {
43 Self {
44 owner_id: peer_id.owner_id,
45 id: peer_id.id,
46 }
47 }
48}
49
50impl fmt::Display for ConnectionId {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{}/{}", self.owner_id, self.id)
53 }
54}
55
56pub struct Receipt<T> {
57 pub sender_id: ConnectionId,
58 pub message_id: u32,
59 payload_type: PhantomData<T>,
60}
61
62impl<T> Clone for Receipt<T> {
63 fn clone(&self) -> Self {
64 Self {
65 sender_id: self.sender_id,
66 message_id: self.message_id,
67 payload_type: PhantomData,
68 }
69 }
70}
71
72impl<T> Copy for Receipt<T> {}
73
74pub struct TypedEnvelope<T> {
75 pub sender_id: ConnectionId,
76 pub original_sender_id: Option<PeerId>,
77 pub message_id: u32,
78 pub payload: T,
79}
80
81impl<T> TypedEnvelope<T> {
82 pub fn original_sender_id(&self) -> Result<PeerId> {
83 self.original_sender_id
84 .clone()
85 .ok_or_else(|| anyhow!("missing original_sender_id"))
86 }
87}
88
89impl<T: RequestMessage> TypedEnvelope<T> {
90 pub fn receipt(&self) -> Receipt<T> {
91 Receipt {
92 sender_id: self.sender_id,
93 message_id: self.message_id,
94 payload_type: PhantomData,
95 }
96 }
97}
98
99pub struct Peer {
100 epoch: AtomicU32,
101 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
102 next_connection_id: AtomicU32,
103}
104
105#[derive(Clone, Serialize)]
106pub struct ConnectionState {
107 #[serde(skip)]
108 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
109 next_message_id: Arc<AtomicU32>,
110 #[allow(clippy::type_complexity)]
111 #[serde(skip)]
112 response_channels:
113 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
114}
115
/// A ping is written when nothing has been sent for this long; the timer restarts
/// after every outgoing write.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on a single socket write, and on handing an incoming message to its
/// consumer, before the connection is treated as failed.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection fails if no message at all is received within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
119
120impl Peer {
121 pub fn new(epoch: u32) -> Arc<Self> {
122 Arc::new(Self {
123 epoch: AtomicU32::new(epoch),
124 connections: Default::default(),
125 next_connection_id: Default::default(),
126 })
127 }
128
129 pub fn epoch(&self) -> u32 {
130 self.epoch.load(SeqCst)
131 }
132
133 #[instrument(skip_all)]
134 pub fn add_connection<F, Fut, Out>(
135 self: &Arc<Self>,
136 connection: Connection,
137 create_timer: F,
138 ) -> (
139 ConnectionId,
140 impl Future<Output = anyhow::Result<()>> + Send,
141 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
142 )
143 where
144 F: Send + Fn(Duration) -> Fut,
145 Fut: Send + Future<Output = Out>,
146 Out: Send,
147 {
148 // For outgoing messages, use an unbounded channel so that application code
149 // can always send messages without yielding. For incoming messages, use a
150 // bounded channel so that other peers will receive backpressure if they send
151 // messages faster than this peer can process them.
152 #[cfg(any(test, feature = "test-support"))]
153 const INCOMING_BUFFER_SIZE: usize = 1;
154 #[cfg(not(any(test, feature = "test-support")))]
155 const INCOMING_BUFFER_SIZE: usize = 64;
156 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
157 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
158
159 let connection_id = ConnectionId {
160 owner_id: self.epoch.load(SeqCst),
161 id: self.next_connection_id.fetch_add(1, SeqCst),
162 };
163 let connection_state = ConnectionState {
164 outgoing_tx,
165 next_message_id: Default::default(),
166 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
167 };
168 let mut writer = MessageStream::new(connection.tx);
169 let mut reader = MessageStream::new(connection.rx);
170
171 let this = self.clone();
172 let response_channels = connection_state.response_channels.clone();
173 let handle_io = async move {
174 tracing::debug!(?connection_id, "handle io future: start");
175
176 let _end_connection = util::defer(|| {
177 response_channels.lock().take();
178 this.connections.write().remove(&connection_id);
179 tracing::debug!(?connection_id, "handle io future: end");
180 });
181
182 // Send messages on this frequency so the connection isn't closed.
183 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
184 futures::pin_mut!(keepalive_timer);
185
186 // Disconnect if we don't receive messages at least this frequently.
187 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
188 futures::pin_mut!(receive_timeout);
189
190 loop {
191 tracing::debug!(?connection_id, "outer loop iteration start");
192 let read_message = reader.read().fuse();
193 futures::pin_mut!(read_message);
194
195 loop {
196 tracing::debug!(?connection_id, "inner loop iteration start");
197 futures::select_biased! {
198 outgoing = outgoing_rx.next().fuse() => match outgoing {
199 Some(outgoing) => {
200 tracing::debug!(?connection_id, "outgoing rpc message: writing");
201 futures::select_biased! {
202 result = writer.write(outgoing).fuse() => {
203 tracing::debug!(?connection_id, "outgoing rpc message: done writing");
204 result.context("failed to write RPC message")?;
205 tracing::debug!(?connection_id, "keepalive interval: resetting after sending message");
206 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
207 }
208 _ = create_timer(WRITE_TIMEOUT).fuse() => {
209 tracing::debug!(?connection_id, "outgoing rpc message: writing timed out");
210 Err(anyhow!("timed out writing message"))?;
211 }
212 }
213 }
214 None => {
215 tracing::debug!(?connection_id, "outgoing rpc message: channel closed");
216 return Ok(())
217 },
218 },
219 _ = keepalive_timer => {
220 tracing::debug!(?connection_id, "keepalive interval: pinging");
221 futures::select_biased! {
222 result = writer.write(proto::Message::Ping).fuse() => {
223 tracing::debug!(?connection_id, "keepalive interval: done pinging");
224 result.context("failed to send keepalive")?;
225 tracing::debug!(?connection_id, "keepalive interval: resetting after pinging");
226 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
227 }
228 _ = create_timer(WRITE_TIMEOUT).fuse() => {
229 tracing::debug!(?connection_id, "keepalive interval: pinging timed out");
230 Err(anyhow!("timed out sending keepalive"))?;
231 }
232 }
233 }
234 incoming = read_message => {
235 let incoming = incoming.context("error reading rpc message from socket")?;
236 tracing::debug!(?connection_id, "incoming rpc message: received");
237 tracing::debug!(?connection_id, "receive timeout: resetting");
238 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
239 if let proto::Message::Envelope(incoming) = incoming {
240 tracing::debug!(?connection_id, "incoming rpc message: processing");
241 futures::select_biased! {
242 result = incoming_tx.send(incoming).fuse() => match result {
243 Ok(_) => {
244 tracing::debug!(?connection_id, "incoming rpc message: processed");
245 }
246 Err(_) => {
247 tracing::debug!(?connection_id, "incoming rpc message: channel closed");
248 return Ok(())
249 }
250 },
251 _ = create_timer(WRITE_TIMEOUT).fuse() => {
252 tracing::debug!(?connection_id, "incoming rpc message: processing timed out");
253 Err(anyhow!("timed out processing incoming message"))?
254 }
255 }
256 }
257 break;
258 },
259 _ = receive_timeout => {
260 tracing::debug!(?connection_id, "receive timeout: delay between messages too long");
261 Err(anyhow!("delay between messages too long"))?
262 }
263 }
264 }
265 }
266 };
267
268 let response_channels = connection_state.response_channels.clone();
269 self.connections
270 .write()
271 .insert(connection_id, connection_state);
272
273 let incoming_rx = incoming_rx.filter_map(move |incoming| {
274 let response_channels = response_channels.clone();
275 async move {
276 let message_id = incoming.id;
277 tracing::debug!(?incoming, "incoming message future: start");
278 let _end = util::defer(move || {
279 tracing::debug!(?connection_id, message_id, "incoming message future: end");
280 });
281
282 if let Some(responding_to) = incoming.responding_to {
283 tracing::debug!(
284 ?connection_id,
285 message_id,
286 responding_to,
287 "incoming response: received"
288 );
289 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
290 if let Some(tx) = channel {
291 let requester_resumed = oneshot::channel();
292 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
293 tracing::debug!(
294 ?connection_id,
295 message_id,
296 responding_to = responding_to,
297 ?error,
298 "incoming response: request future dropped",
299 );
300 }
301
302 tracing::debug!(
303 ?connection_id,
304 message_id,
305 responding_to,
306 "incoming response: waiting to resume requester"
307 );
308 let _ = requester_resumed.1.await;
309 tracing::debug!(
310 ?connection_id,
311 message_id,
312 responding_to,
313 "incoming response: requester resumed"
314 );
315 } else {
316 tracing::warn!(
317 ?connection_id,
318 message_id,
319 responding_to,
320 "incoming response: unknown request"
321 );
322 }
323
324 None
325 } else {
326 tracing::debug!(?connection_id, message_id, "incoming message: received");
327 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
328 tracing::error!(
329 ?connection_id,
330 message_id,
331 "unable to construct a typed envelope"
332 );
333 None
334 })
335 }
336 }
337 });
338 (connection_id, handle_io, incoming_rx.boxed())
339 }
340
341 #[cfg(any(test, feature = "test-support"))]
342 pub fn add_test_connection(
343 self: &Arc<Self>,
344 connection: Connection,
345 executor: Arc<gpui::executor::Background>,
346 ) -> (
347 ConnectionId,
348 impl Future<Output = anyhow::Result<()>> + Send,
349 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
350 ) {
351 let executor = executor.clone();
352 self.add_connection(connection, move |duration| executor.timer(duration))
353 }
354
355 pub fn disconnect(&self, connection_id: ConnectionId) {
356 self.connections.write().remove(&connection_id);
357 }
358
359 pub fn reset(&self, epoch: u32) {
360 self.teardown();
361 self.next_connection_id.store(0, SeqCst);
362 self.epoch.store(epoch, SeqCst);
363 }
364
365 pub fn teardown(&self) {
366 self.connections.write().clear();
367 }
368
369 pub fn request<T: RequestMessage>(
370 &self,
371 receiver_id: ConnectionId,
372 request: T,
373 ) -> impl Future<Output = Result<T::Response>> {
374 self.request_internal(None, receiver_id, request)
375 }
376
377 pub fn forward_request<T: RequestMessage>(
378 &self,
379 sender_id: ConnectionId,
380 receiver_id: ConnectionId,
381 request: T,
382 ) -> impl Future<Output = Result<T::Response>> {
383 self.request_internal(Some(sender_id), receiver_id, request)
384 }
385
386 pub fn request_internal<T: RequestMessage>(
387 &self,
388 original_sender_id: Option<ConnectionId>,
389 receiver_id: ConnectionId,
390 request: T,
391 ) -> impl Future<Output = Result<T::Response>> {
392 let (tx, rx) = oneshot::channel();
393 let send = self.connection_state(receiver_id).and_then(|connection| {
394 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
395 connection
396 .response_channels
397 .lock()
398 .as_mut()
399 .ok_or_else(|| anyhow!("connection was closed"))?
400 .insert(message_id, tx);
401 connection
402 .outgoing_tx
403 .unbounded_send(proto::Message::Envelope(request.into_envelope(
404 message_id,
405 None,
406 original_sender_id.map(Into::into),
407 )))
408 .map_err(|_| anyhow!("connection was closed"))?;
409 Ok(())
410 });
411 async move {
412 send?;
413 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
414 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
415 Err(anyhow!(
416 "RPC request {} failed - {}",
417 T::NAME,
418 error.message
419 ))
420 } else {
421 T::Response::from_envelope(response)
422 .ok_or_else(|| anyhow!("received response of the wrong type"))
423 }
424 }
425 }
426
427 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
428 let connection = self.connection_state(receiver_id)?;
429 let message_id = connection
430 .next_message_id
431 .fetch_add(1, atomic::Ordering::SeqCst);
432 connection
433 .outgoing_tx
434 .unbounded_send(proto::Message::Envelope(
435 message.into_envelope(message_id, None, None),
436 ))?;
437 Ok(())
438 }
439
440 pub fn forward_send<T: EnvelopedMessage>(
441 &self,
442 sender_id: ConnectionId,
443 receiver_id: ConnectionId,
444 message: T,
445 ) -> Result<()> {
446 let connection = self.connection_state(receiver_id)?;
447 let message_id = connection
448 .next_message_id
449 .fetch_add(1, atomic::Ordering::SeqCst);
450 connection
451 .outgoing_tx
452 .unbounded_send(proto::Message::Envelope(message.into_envelope(
453 message_id,
454 None,
455 Some(sender_id.into()),
456 )))?;
457 Ok(())
458 }
459
460 pub fn respond<T: RequestMessage>(
461 &self,
462 receipt: Receipt<T>,
463 response: T::Response,
464 ) -> Result<()> {
465 let connection = self.connection_state(receipt.sender_id)?;
466 let message_id = connection
467 .next_message_id
468 .fetch_add(1, atomic::Ordering::SeqCst);
469 connection
470 .outgoing_tx
471 .unbounded_send(proto::Message::Envelope(response.into_envelope(
472 message_id,
473 Some(receipt.message_id),
474 None,
475 )))?;
476 Ok(())
477 }
478
479 pub fn respond_with_error<T: RequestMessage>(
480 &self,
481 receipt: Receipt<T>,
482 response: proto::Error,
483 ) -> Result<()> {
484 let connection = self.connection_state(receipt.sender_id)?;
485 let message_id = connection
486 .next_message_id
487 .fetch_add(1, atomic::Ordering::SeqCst);
488 connection
489 .outgoing_tx
490 .unbounded_send(proto::Message::Envelope(response.into_envelope(
491 message_id,
492 Some(receipt.message_id),
493 None,
494 )))?;
495 Ok(())
496 }
497
498 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
499 let connections = self.connections.read();
500 let connection = connections
501 .get(&connection_id)
502 .ok_or_else(|| anyhow!("no such connection: {:?}", connection_id))?;
503 Ok(connection.clone())
504 }
505}
506
507impl Serialize for Peer {
508 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
509 where
510 S: serde::Serializer,
511 {
512 let mut state = serializer.serialize_struct("Peer", 2)?;
513 state.serialize_field("connections", &*self.connections.read())?;
514 state.end()
515 }
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[ctor::ctor]
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::init();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.background());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.background());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.background());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; any other message panics.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.background());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.background());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, cx.background());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}