1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use smol_timeout::TimeoutExt;
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25use tracing::instrument;
26
/// Identifies a single connection registered with a [`Peer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Formats the raw numeric id, forwarding any width/fill flags to `u32`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
35
/// Numeric identifier of a peer, as carried in an envelope's original-sender field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Formats the raw numeric id, forwarding any width/fill flags to `u32`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
44
45pub struct Receipt<T> {
46 pub sender_id: ConnectionId,
47 pub message_id: u32,
48 payload_type: PhantomData<T>,
49}
50
51impl<T> Clone for Receipt<T> {
52 fn clone(&self) -> Self {
53 Self {
54 sender_id: self.sender_id,
55 message_id: self.message_id,
56 payload_type: PhantomData,
57 }
58 }
59}
60
61impl<T> Copy for Receipt<T> {}
62
63pub struct TypedEnvelope<T> {
64 pub sender_id: ConnectionId,
65 pub original_sender_id: Option<PeerId>,
66 pub message_id: u32,
67 pub payload: T,
68}
69
70impl<T> TypedEnvelope<T> {
71 pub fn original_sender_id(&self) -> Result<PeerId> {
72 self.original_sender_id
73 .ok_or_else(|| anyhow!("missing original_sender_id"))
74 }
75}
76
77impl<T: RequestMessage> TypedEnvelope<T> {
78 pub fn receipt(&self) -> Receipt<T> {
79 Receipt {
80 sender_id: self.sender_id,
81 message_id: self.message_id,
82 payload_type: PhantomData,
83 }
84 }
85}
86
/// Multiplexes RPC messages over a set of connections, each identified by a
/// [`ConnectionId`].
pub struct Peer {
    /// State for every live connection. Entries are inserted by `add_connection`
    /// and removed by `disconnect`, `reset`, or when the connection's I/O future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Source of ids for newly added connections; incremented once per connection.
    next_connection_id: AtomicU32,
}
91
/// Per-connection state shared between the connection's I/O future and the
/// `Peer` methods that send on it. Cloning copies channel handles and `Arc`s,
/// not the underlying data.
#[derive(Clone)]
pub struct ConnectionState {
    /// Queue of messages that the connection's I/O future writes to the socket.
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Counter used to assign ids to messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Pending requests, keyed by the id of the request message. Each sender is
    /// handed the response envelope plus a "barrier" sender; the incoming-message
    /// stream waits for that barrier to be released before processing further
    /// messages. Set to `None` when the connection ends, so that pending requests
    /// are woken with an error and new requests fail immediately.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
99
/// How long the connection may go without sending anything before a ping is
/// written, so the connection isn't closed by the other side.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on writing a single message (or keepalive ping) to the socket.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// Maximum gap between received messages before the connection is dropped.
/// It also bounds how long an incoming message may wait to be queued for processing.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
103
104impl Peer {
105 pub fn new() -> Arc<Self> {
106 Arc::new(Self {
107 connections: Default::default(),
108 next_connection_id: Default::default(),
109 })
110 }
111
112 #[instrument(skip_all)]
113 pub async fn add_connection<F, Fut, Out>(
114 self: &Arc<Self>,
115 connection: Connection,
116 create_timer: F,
117 ) -> (
118 ConnectionId,
119 impl Future<Output = anyhow::Result<()>> + Send,
120 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
121 )
122 where
123 F: Send + Fn(Duration) -> Fut,
124 Fut: Send + Future<Output = Out>,
125 Out: Send,
126 {
127 // For outgoing messages, use an unbounded channel so that application code
128 // can always send messages without yielding. For incoming messages, use a
129 // bounded channel so that other peers will receive backpressure if they send
130 // messages faster than this peer can process them.
131 #[cfg(any(test, feature = "test-support"))]
132 const INCOMING_BUFFER_SIZE: usize = 1;
133 #[cfg(not(any(test, feature = "test-support")))]
134 const INCOMING_BUFFER_SIZE: usize = 64;
135 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
136 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
137
138 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
139 let connection_state = ConnectionState {
140 outgoing_tx: outgoing_tx.clone(),
141 next_message_id: Default::default(),
142 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
143 };
144 let mut writer = MessageStream::new(connection.tx);
145 let mut reader = MessageStream::new(connection.rx);
146
147 let this = self.clone();
148 let response_channels = connection_state.response_channels.clone();
149 let handle_io = async move {
150 tracing::debug!(%connection_id, "handle io future: start");
151
152 let _end_connection = util::defer(|| {
153 response_channels.lock().take();
154 this.connections.write().remove(&connection_id);
155 tracing::debug!(%connection_id, "handle io future: end");
156 });
157
158 // Send messages on this frequency so the connection isn't closed.
159 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
160 futures::pin_mut!(keepalive_timer);
161
162 // Disconnect if we don't receive messages at least this frequently.
163 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
164 futures::pin_mut!(receive_timeout);
165
166 loop {
167 tracing::debug!(%connection_id, "outer loop iteration start");
168 let read_message = reader.read().fuse();
169 futures::pin_mut!(read_message);
170
171 loop {
172 tracing::debug!(%connection_id, "inner loop iteration start");
173 futures::select_biased! {
174 outgoing = outgoing_rx.next().fuse() => match outgoing {
175 Some(outgoing) => {
176 tracing::debug!(%connection_id, "outgoing rpc message: writing");
177 if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
178 tracing::debug!(%connection_id, "outgoing rpc message: done writing");
179 result.context("failed to write RPC message")?;
180 tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
181 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
182 } else {
183 tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
184 Err(anyhow!("timed out writing message"))?;
185 }
186 }
187 None => {
188 tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
189 return Ok(())
190 },
191 },
192 incoming = read_message => {
193 let incoming = incoming.context("error reading rpc message from socket")?;
194 tracing::debug!(%connection_id, "incoming rpc message: received");
195 tracing::debug!(%connection_id, "receive timeout: resetting");
196 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
197 if let proto::Message::Envelope(incoming) = incoming {
198 tracing::debug!(%connection_id, "incoming rpc message: processing");
199 match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
200 Some(Ok(_)) => {
201 tracing::debug!(%connection_id, "incoming rpc message: processed");
202 },
203 Some(Err(_)) => {
204 tracing::debug!(%connection_id, "incoming rpc message: channel closed");
205 return Ok(())
206 },
207 None => {
208 tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
209 Err(anyhow!("timed out processing incoming message"))?
210 },
211 }
212 }
213 break;
214 },
215 _ = keepalive_timer => {
216 tracing::debug!(%connection_id, "keepalive interval: pinging");
217 if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
218 tracing::debug!(%connection_id, "keepalive interval: done pinging");
219 result.context("failed to send keepalive")?;
220 tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
221 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
222 } else {
223 tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
224 Err(anyhow!("timed out sending keepalive"))?;
225 }
226 }
227 _ = receive_timeout => {
228 tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
229 Err(anyhow!("delay between messages too long"))?
230 }
231 }
232 }
233 }
234 };
235
236 let response_channels = connection_state.response_channels.clone();
237 self.connections
238 .write()
239 .insert(connection_id, connection_state);
240
241 let incoming_rx = incoming_rx.filter_map(move |incoming| {
242 let response_channels = response_channels.clone();
243 async move {
244 let message_id = incoming.id;
245 tracing::debug!(?incoming, "incoming message future: start");
246 let _end = util::defer(move || {
247 tracing::debug!(
248 %connection_id,
249 message_id,
250 "incoming message future: end"
251 );
252 });
253
254 if let Some(responding_to) = incoming.responding_to {
255 tracing::debug!(
256 %connection_id,
257 message_id,
258 responding_to,
259 "incoming response: received"
260 );
261 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
262 if let Some(tx) = channel {
263 let requester_resumed = oneshot::channel();
264 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
265 tracing::debug!(
266 %connection_id,
267 message_id,
268 responding_to = responding_to,
269 ?error,
270 "incoming response: request future dropped",
271 );
272 }
273
274 tracing::debug!(
275 %connection_id,
276 message_id,
277 responding_to,
278 "incoming response: waiting to resume requester"
279 );
280 let _ = requester_resumed.1.await;
281 tracing::debug!(
282 %connection_id,
283 message_id,
284 responding_to,
285 "incoming response: requester resumed"
286 );
287 } else {
288 tracing::warn!(
289 %connection_id,
290 message_id,
291 responding_to,
292 "incoming response: unknown request"
293 );
294 }
295
296 None
297 } else {
298 tracing::debug!(
299 %connection_id,
300 message_id,
301 "incoming message: received"
302 );
303 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
304 tracing::error!(
305 %connection_id,
306 message_id,
307 "unable to construct a typed envelope"
308 );
309 None
310 })
311 }
312 }
313 });
314 (connection_id, handle_io, incoming_rx.boxed())
315 }
316
317 #[cfg(any(test, feature = "test-support"))]
318 pub async fn add_test_connection(
319 self: &Arc<Self>,
320 connection: Connection,
321 executor: Arc<gpui::executor::Background>,
322 ) -> (
323 ConnectionId,
324 impl Future<Output = anyhow::Result<()>> + Send,
325 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
326 ) {
327 let executor = executor.clone();
328 self.add_connection(connection, move |duration| executor.timer(duration))
329 .await
330 }
331
332 pub fn disconnect(&self, connection_id: ConnectionId) {
333 self.connections.write().remove(&connection_id);
334 }
335
336 pub fn reset(&self) {
337 self.connections.write().clear();
338 }
339
340 pub fn request<T: RequestMessage>(
341 &self,
342 receiver_id: ConnectionId,
343 request: T,
344 ) -> impl Future<Output = Result<T::Response>> {
345 self.request_internal(None, receiver_id, request)
346 }
347
348 pub fn forward_request<T: RequestMessage>(
349 &self,
350 sender_id: ConnectionId,
351 receiver_id: ConnectionId,
352 request: T,
353 ) -> impl Future<Output = Result<T::Response>> {
354 self.request_internal(Some(sender_id), receiver_id, request)
355 }
356
357 pub fn request_internal<T: RequestMessage>(
358 &self,
359 original_sender_id: Option<ConnectionId>,
360 receiver_id: ConnectionId,
361 request: T,
362 ) -> impl Future<Output = Result<T::Response>> {
363 let (tx, rx) = oneshot::channel();
364 let send = self.connection_state(receiver_id).and_then(|connection| {
365 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
366 connection
367 .response_channels
368 .lock()
369 .as_mut()
370 .ok_or_else(|| anyhow!("connection was closed"))?
371 .insert(message_id, tx);
372 connection
373 .outgoing_tx
374 .unbounded_send(proto::Message::Envelope(request.into_envelope(
375 message_id,
376 None,
377 original_sender_id.map(|id| id.0),
378 )))
379 .map_err(|_| anyhow!("connection was closed"))?;
380 Ok(())
381 });
382 async move {
383 send?;
384 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
385 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
386 Err(anyhow!("RPC request failed - {}", error.message))
387 } else {
388 T::Response::from_envelope(response)
389 .ok_or_else(|| anyhow!("received response of the wrong type"))
390 }
391 }
392 }
393
394 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
395 let connection = self.connection_state(receiver_id)?;
396 let message_id = connection
397 .next_message_id
398 .fetch_add(1, atomic::Ordering::SeqCst);
399 connection
400 .outgoing_tx
401 .unbounded_send(proto::Message::Envelope(
402 message.into_envelope(message_id, None, None),
403 ))?;
404 Ok(())
405 }
406
407 pub fn forward_send<T: EnvelopedMessage>(
408 &self,
409 sender_id: ConnectionId,
410 receiver_id: ConnectionId,
411 message: T,
412 ) -> Result<()> {
413 let connection = self.connection_state(receiver_id)?;
414 let message_id = connection
415 .next_message_id
416 .fetch_add(1, atomic::Ordering::SeqCst);
417 connection
418 .outgoing_tx
419 .unbounded_send(proto::Message::Envelope(message.into_envelope(
420 message_id,
421 None,
422 Some(sender_id.0),
423 )))?;
424 Ok(())
425 }
426
427 pub fn respond<T: RequestMessage>(
428 &self,
429 receipt: Receipt<T>,
430 response: T::Response,
431 ) -> Result<()> {
432 let connection = self.connection_state(receipt.sender_id)?;
433 let message_id = connection
434 .next_message_id
435 .fetch_add(1, atomic::Ordering::SeqCst);
436 connection
437 .outgoing_tx
438 .unbounded_send(proto::Message::Envelope(response.into_envelope(
439 message_id,
440 Some(receipt.message_id),
441 None,
442 )))?;
443 Ok(())
444 }
445
446 pub fn respond_with_error<T: RequestMessage>(
447 &self,
448 receipt: Receipt<T>,
449 response: proto::Error,
450 ) -> Result<()> {
451 let connection = self.connection_state(receipt.sender_id)?;
452 let message_id = connection
453 .next_message_id
454 .fetch_add(1, atomic::Ordering::SeqCst);
455 connection
456 .outgoing_tx
457 .unbounded_send(proto::Message::Envelope(response.into_envelope(
458 message_id,
459 Some(receipt.message_id),
460 None,
461 )))?;
462 Ok(())
463 }
464
465 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
466 let connections = self.connections.read();
467 let connection = connections
468 .get(&connection_id)
469 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
470 Ok(connection.clone())
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}