1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25use tracing::instrument;
26
/// Identifies a live connection held by this peer. Ids are assigned locally
/// from `Peer::next_connection_id` and are not shared across peers.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId(pub u32);
29
30impl fmt::Display for ConnectionId {
31 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32 self.0.fmt(f)
33 }
34}
35
/// Identifies a peer; wraps the raw `u32` carried inside message envelopes
/// (see `TypedEnvelope::original_sender_id` and the forwarding methods on `Peer`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
38
39impl fmt::Display for PeerId {
40 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
41 self.0.fmt(f)
42 }
43}
44
/// A handle for responding to a received request of payload type `T`.
/// Captures the connection the request arrived on and the request's message
/// id, so a response envelope can reference it via `responding_to`.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    // Zero-sized marker tying the receipt to the request's payload type.
    payload_type: PhantomData<T>,
}
50
51impl<T> Clone for Receipt<T> {
52 fn clone(&self) -> Self {
53 Self {
54 sender_id: self.sender_id,
55 message_id: self.message_id,
56 payload_type: PhantomData,
57 }
58 }
59}
60
61impl<T> Copy for Receipt<T> {}
62
/// A decoded incoming message of payload type `T`, together with routing
/// metadata about where it came from.
pub struct TypedEnvelope<T> {
    /// Connection this envelope arrived on.
    pub sender_id: ConnectionId,
    /// Set when the message was forwarded on behalf of another peer
    /// (see `Peer::forward_send` / `Peer::forward_request`).
    pub original_sender_id: Option<PeerId>,
    /// Sender-assigned id of this message.
    pub message_id: u32,
    pub payload: T,
}
69
70impl<T> TypedEnvelope<T> {
71 pub fn original_sender_id(&self) -> Result<PeerId> {
72 self.original_sender_id
73 .ok_or_else(|| anyhow!("missing original_sender_id"))
74 }
75}
76
77impl<T: RequestMessage> TypedEnvelope<T> {
78 pub fn receipt(&self) -> Receipt<T> {
79 Receipt {
80 sender_id: self.sender_id,
81 message_id: self.message_id,
82 payload_type: PhantomData,
83 }
84 }
85}
86
/// A peer in the RPC system: owns the set of live connections and assigns
/// connection ids. Created behind an `Arc` via `Peer::new` so per-connection
/// I/O futures can hold a reference back to it.
pub struct Peer {
    /// All live connections, keyed by locally-assigned id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Source of monotonically increasing connection ids.
    next_connection_id: AtomicU32,
}
91
/// Per-connection bookkeeping, shared (via the `Arc`s inside) between the
/// `Peer` and the connection's I/O future.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    /// Queue of messages waiting to be written to the socket.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Next message id to assign for messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    /// Pending request responders keyed by request message id. The outer
    /// `Option` is taken (set to `None`) when the connection's I/O future
    /// ends, which drops all senders and fails in-flight requests. Each entry
    /// also carries a "barrier" sender that the requester drops once it has
    /// resumed, letting the incoming stream preserve message ordering.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
102
/// Send a ping at least this often when no other outgoing traffic resets the timer.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for a single write (message or ping) before the connection errors out.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Disconnect when no message (including keepalive pings) arrives within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
106
107impl Peer {
108 pub fn new() -> Arc<Self> {
109 Arc::new(Self {
110 connections: Default::default(),
111 next_connection_id: Default::default(),
112 })
113 }
114
    /// Registers `connection` with this peer.
    ///
    /// Returns the new connection's id, a future that drives the connection's
    /// I/O until it ends, and a stream of decoded incoming messages.
    /// `create_timer` produces sleep futures and is injected so tests can
    /// control time; it drives the keepalive, write, and receive timers.
    #[instrument(skip_all)]
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::debug!(%connection_id, "handle io future: start");

            // On any exit path (success, error, or drop): fail all in-flight
            // requests by dropping their response channels, and deregister
            // the connection from the peer.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::debug!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::debug!(%connection_id, "outer loop iteration start");
                // One read future per outer iteration; the inner loop keeps
                // polling it (alongside writes and timers) until it completes,
                // then `break`s back out to create a fresh one.
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::debug!(%connection_id, "inner loop iteration start");
                    // `select_biased!` checks branches in order, so draining
                    // outgoing messages takes priority over the keepalive
                    // timer, the pending read, and the receive timeout.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::debug!(%connection_id, "outgoing rpc message: writing");
                                // Race the write against WRITE_TIMEOUT; a
                                // successful write also resets the keepalive.
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            // All senders dropped (e.g. `Peer::disconnect`):
                            // end the connection cleanly.
                            None => {
                                tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::debug!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::debug!(%connection_id, "incoming rpc message: received");
                            tracing::debug!(%connection_id, "receive timeout: resetting");
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            // Pings only reset the receive timeout; envelopes
                            // are forwarded to the (bounded) incoming channel,
                            // applying backpressure within WRITE_TIMEOUT.
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::debug!(%connection_id, "incoming rpc message: processing");
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::debug!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // Read completed: go create a fresh read future.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Post-process raw incoming envelopes: route responses to their
        // waiting requesters, decode everything else into typed envelopes.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::debug!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message future: end"
                    );
                });

                if let Some(responding_to) = incoming.responding_to {
                    // This is a response: hand it to the pending request, then
                    // wait for the requester to resume (drop its barrier)
                    // before processing the next incoming message, so that
                    // responses are observed in order relative to other
                    // incoming messages.
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `?` here: if the channels were taken, the connection is
                    // closed and the message is dropped.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    // Responses are consumed here, never yielded downstream.
                    None
                } else {
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        "incoming message: received"
                    );
                    // Decode into a typed envelope; undecodable payloads are
                    // logged and dropped.
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
327
328 #[cfg(any(test, feature = "test-support"))]
329 pub async fn add_test_connection(
330 self: &Arc<Self>,
331 connection: Connection,
332 executor: Arc<gpui::executor::Background>,
333 ) -> (
334 ConnectionId,
335 impl Future<Output = anyhow::Result<()>> + Send,
336 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
337 ) {
338 let executor = executor.clone();
339 self.add_connection(connection, move |duration| executor.timer(duration))
340 .await
341 }
342
343 pub fn disconnect(&self, connection_id: ConnectionId) {
344 self.connections.write().remove(&connection_id);
345 }
346
347 pub fn reset(&self) {
348 self.connections.write().clear();
349 }
350
    /// Sends `request` to `receiver_id` and resolves with the typed response.
    /// Fails if the connection is unknown or closes before a response arrives.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
358
    /// Like [`Peer::request`], but stamps the envelope with `sender_id` as the
    /// original sender, for requests forwarded on behalf of another peer.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
367
    /// Shared implementation of [`Peer::request`] and [`Peer::forward_request`].
    ///
    /// Registers a oneshot response channel under a fresh message id, then
    /// enqueues the request envelope. The returned future resolves when the
    /// response arrives; dropping its `_barrier` half signals the incoming
    /// message stream that the requester has resumed (see `add_connection`).
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        // Perform the fallible setup eagerly, before the future is awaited,
        // so the request is enqueued even if the caller delays polling.
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                // `None` means the connection's I/O future already ended.
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                )))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            // An `Error` payload means the remote rejected the request.
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }
404
405 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
406 let connection = self.connection_state(receiver_id)?;
407 let message_id = connection
408 .next_message_id
409 .fetch_add(1, atomic::Ordering::SeqCst);
410 connection
411 .outgoing_tx
412 .unbounded_send(proto::Message::Envelope(
413 message.into_envelope(message_id, None, None),
414 ))?;
415 Ok(())
416 }
417
418 pub fn forward_send<T: EnvelopedMessage>(
419 &self,
420 sender_id: ConnectionId,
421 receiver_id: ConnectionId,
422 message: T,
423 ) -> Result<()> {
424 let connection = self.connection_state(receiver_id)?;
425 let message_id = connection
426 .next_message_id
427 .fetch_add(1, atomic::Ordering::SeqCst);
428 connection
429 .outgoing_tx
430 .unbounded_send(proto::Message::Envelope(message.into_envelope(
431 message_id,
432 None,
433 Some(sender_id.0),
434 )))?;
435 Ok(())
436 }
437
438 pub fn respond<T: RequestMessage>(
439 &self,
440 receipt: Receipt<T>,
441 response: T::Response,
442 ) -> Result<()> {
443 let connection = self.connection_state(receipt.sender_id)?;
444 let message_id = connection
445 .next_message_id
446 .fetch_add(1, atomic::Ordering::SeqCst);
447 connection
448 .outgoing_tx
449 .unbounded_send(proto::Message::Envelope(response.into_envelope(
450 message_id,
451 Some(receipt.message_id),
452 None,
453 )))?;
454 Ok(())
455 }
456
457 pub fn respond_with_error<T: RequestMessage>(
458 &self,
459 receipt: Receipt<T>,
460 response: proto::Error,
461 ) -> Result<()> {
462 let connection = self.connection_state(receipt.sender_id)?;
463 let message_id = connection
464 .next_message_id
465 .fetch_add(1, atomic::Ordering::SeqCst);
466 connection
467 .outgoing_tx
468 .unbounded_send(proto::Message::Envelope(response.into_envelope(
469 message_id,
470 Some(receipt.message_id),
471 None,
472 )))?;
473 Ok(())
474 }
475
476 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
477 let connections = self.connections.read();
478 let connection = connections
479 .get(&connection_id)
480 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
481 Ok(connection.clone())
482 }
483}
484
485impl Serialize for Peer {
486 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
487 where
488 S: serde::Serializer,
489 {
490 let mut state = serializer.serialize_struct("Peer", 2)?;
491 state.serialize_field("connections", &*self.connections.read())?;
492 state.end()
493 }
494}
495
496#[cfg(test)]
497mod tests {
498 use super::*;
499 use crate::TypedEnvelope;
500 use async_tungstenite::tungstenite::Message as WebSocketMessage;
501 use gpui::TestAppContext;
502
503 #[ctor::ctor]
504 fn init_logger() {
505 if std::env::var("RUST_LOG").is_ok() {
506 env_logger::init();
507 }
508 }
509
510 #[gpui::test(iterations = 50)]
511 async fn test_request_response(cx: &mut TestAppContext) {
512 let executor = cx.foreground();
513
514 // create 2 clients connected to 1 server
515 let server = Peer::new();
516 let client1 = Peer::new();
517 let client2 = Peer::new();
518
519 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
520 Connection::in_memory(cx.background());
521 let (client1_conn_id, io_task1, client1_incoming) = client1
522 .add_test_connection(client1_to_server_conn, cx.background())
523 .await;
524 let (_, io_task2, server_incoming1) = server
525 .add_test_connection(server_to_client_1_conn, cx.background())
526 .await;
527
528 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
529 Connection::in_memory(cx.background());
530 let (client2_conn_id, io_task3, client2_incoming) = client2
531 .add_test_connection(client2_to_server_conn, cx.background())
532 .await;
533 let (_, io_task4, server_incoming2) = server
534 .add_test_connection(server_to_client_2_conn, cx.background())
535 .await;
536
537 executor.spawn(io_task1).detach();
538 executor.spawn(io_task2).detach();
539 executor.spawn(io_task3).detach();
540 executor.spawn(io_task4).detach();
541 executor
542 .spawn(handle_messages(server_incoming1, server.clone()))
543 .detach();
544 executor
545 .spawn(handle_messages(client1_incoming, client1.clone()))
546 .detach();
547 executor
548 .spawn(handle_messages(server_incoming2, server.clone()))
549 .detach();
550 executor
551 .spawn(handle_messages(client2_incoming, client2.clone()))
552 .detach();
553
554 assert_eq!(
555 client1
556 .request(client1_conn_id, proto::Ping {},)
557 .await
558 .unwrap(),
559 proto::Ack {}
560 );
561
562 assert_eq!(
563 client2
564 .request(client2_conn_id, proto::Ping {},)
565 .await
566 .unwrap(),
567 proto::Ack {}
568 );
569
570 assert_eq!(
571 client1
572 .request(client1_conn_id, proto::Test { id: 1 },)
573 .await
574 .unwrap(),
575 proto::Test { id: 1 }
576 );
577
578 assert_eq!(
579 client2
580 .request(client2_conn_id, proto::Test { id: 2 })
581 .await
582 .unwrap(),
583 proto::Test { id: 2 }
584 );
585
586 client1.disconnect(client1_conn_id);
587 client2.disconnect(client1_conn_id);
588
589 async fn handle_messages(
590 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
591 peer: Arc<Peer>,
592 ) -> Result<()> {
593 while let Some(envelope) = messages.next().await {
594 let envelope = envelope.into_any();
595 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
596 let receipt = envelope.receipt();
597 peer.respond(receipt, proto::Ack {})?
598 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
599 {
600 peer.respond(envelope.receipt(), envelope.payload.clone())?
601 } else {
602 panic!("unknown message type");
603 }
604 }
605
606 Ok(())
607 }
608 }
609
    /// Verifies ordering: a response must be observed by the requester only
    /// after all messages the server sent *before* that response have been
    /// processed. This is the behavior the barrier in `add_connection`'s
    /// incoming stream exists to guarantee.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive the Ping, send two plain messages first, and
        // only then respond to the request.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Record the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: drain the two plain messages; their events must land
        // before the response's event.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
709
    /// Verifies that dropping an in-flight request future mid-stream does not
    /// stall or disorder delivery: a second concurrent request still completes,
    /// and it completes after the messages sent ahead of its response.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both Pings, send two plain messages, then
        // respond to both requests (including the one the client will drop).
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Record the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: drain the two plain messages sent ahead of the responses.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
823
    /// Verifies that `Peer::disconnect` tears everything down: the I/O future
    /// finishes, the incoming message stream ends, and the remote half of the
    /// connection can no longer send.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        // Signal when the I/O future completes.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signal when the incoming stream yields its final item.
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        // The remote end should now observe the closed socket.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
860
    /// Verifies that when the underlying connection dies mid-request, the
    /// pending request future resolves with a "connection was closed" error
    /// (the I/O future drops the response channels on exit).
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request actually reaches the wire before killing it.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
884}