1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use serde::{ser::SerializeStruct, Serialize};
14use smol_timeout::TimeoutExt;
15use std::sync::atomic::Ordering::SeqCst;
16use std::{
17 fmt,
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24 time::Duration,
25};
26use tracing::instrument;
27
/// Identifies one connection registered with a [`Peer`].
///
/// Ids are handed out by `Peer::add_connection` from the peer's own counter, so they are
/// only meaningful relative to the `Peer` that issued them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug, Serialize)]
pub struct ConnectionId(pub u32);
30
31impl fmt::Display for ConnectionId {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 self.0.fmt(f)
34 }
35}
36
/// Numeric identifier of a peer.
///
/// In this file it is the type of `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
39
40impl fmt::Display for PeerId {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 self.0.fmt(f)
43 }
44}
45
/// Identifies a received request of type `T` so that a response can be routed back to it.
///
/// Built by `TypedEnvelope::receipt` and consumed by `Peer::respond` /
/// `Peer::respond_with_error`.
pub struct Receipt<T> {
    /// Connection the request was received on; the response is sent over this connection.
    pub sender_id: ConnectionId,
    /// Id of the request message that the response answers.
    pub message_id: u32,
    // Ties the receipt to the request type without storing a value of that type.
    payload_type: PhantomData<T>,
}
51
52impl<T> Clone for Receipt<T> {
53 fn clone(&self) -> Self {
54 Self {
55 sender_id: self.sender_id,
56 message_id: self.message_id,
57 payload_type: PhantomData,
58 }
59 }
60}
61
62impl<T> Copy for Receipt<T> {}
63
/// A message payload of type `T` together with its routing metadata.
pub struct TypedEnvelope<T> {
    /// Connection associated with the sender of this message (presumably the connection it
    /// was received on; the envelope is built by `proto::build_typed_envelope` — confirm there).
    pub sender_id: ConnectionId,
    /// Peer that originally sent the message, when one was recorded (e.g. for forwarded
    /// messages — TODO confirm against `proto`).
    pub original_sender_id: Option<PeerId>,
    /// Id of this message; copied into the `Receipt` produced by `receipt()`.
    pub message_id: u32,
    /// The decoded message itself.
    pub payload: T,
}
70
71impl<T> TypedEnvelope<T> {
72 pub fn original_sender_id(&self) -> Result<PeerId> {
73 self.original_sender_id
74 .ok_or_else(|| anyhow!("missing original_sender_id"))
75 }
76}
77
78impl<T: RequestMessage> TypedEnvelope<T> {
79 pub fn receipt(&self) -> Receipt<T> {
80 Receipt {
81 sender_id: self.sender_id,
82 message_id: self.message_id,
83 payload_type: PhantomData,
84 }
85 }
86}
87
/// Tracks a set of connections and sends/receives RPC messages over them.
pub struct Peer {
    /// State of every currently registered connection, keyed by its id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of the next `ConnectionId`; incremented once per `add_connection` call.
    next_connection_id: AtomicU32,
}
92
/// Per-connection state stored in `Peer::connections`.
///
/// Cloning is cheap-ish and shares state: the counters and response map are behind `Arc`s and
/// the sender is a channel handle. Only `next_message_id` is serialized; the other fields are
/// marked `#[serde(skip)]`.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    // Queue of messages waiting to be written by the connection's I/O future.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    // Id to assign to the next message sent over this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending requests keyed by request message id. Each sender receives the response
    // envelope plus a "barrier" sender that the requester drops once it has resumed.
    // The `Option` is set to `None` when the connection's I/O future ends.
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
102
/// How long the I/O future waits after the last successful write before sending a ping.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for writing a single message (including keepalive pings).
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum time allowed between received messages, and for handing a received message to the
/// incoming channel, before the connection is failed.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
106
107impl Peer {
108 pub fn new() -> Arc<Self> {
109 Arc::new(Self {
110 connections: Default::default(),
111 next_connection_id: Default::default(),
112 })
113 }
114
/// Registers `connection` with this peer and returns the pieces needed to drive it.
///
/// Returns a tuple of:
/// - the [`ConnectionId`] assigned to the connection,
/// - a future that performs the connection's I/O (writing queued outgoing messages, sending
///   keepalive pings, and reading incoming messages); no I/O happens until it is polled,
/// - a stream of incoming messages that are *not* responses, as typed envelopes. Responses
///   are instead delivered to the pending request registered under the id they respond to.
///
/// `create_timer` must build a future that completes after the given duration; it is used
/// for the keepalive interval and the receive timeout.
///
/// The I/O future resolves to `Ok(())` when the outgoing or incoming channel is closed, and
/// to an error when a read or write fails or times out, or when no message is received
/// within [`RECEIVE_TIMEOUT`].
#[instrument(skip_all)]
pub async fn add_connection<F, Fut, Out>(
    self: &Arc<Self>,
    connection: Connection,
    create_timer: F,
) -> (
    ConnectionId,
    impl Future<Output = anyhow::Result<()>> + Send,
    BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
)
where
    F: Send + Fn(Duration) -> Fut,
    Fut: Send + Future<Output = Out>,
    Out: Send,
{
    // For outgoing messages, use an unbounded channel so that application code
    // can always send messages without yielding. For incoming messages, use a
    // bounded channel so that other peers will receive backpressure if they send
    // messages faster than this peer can process them.
    #[cfg(any(test, feature = "test-support"))]
    const INCOMING_BUFFER_SIZE: usize = 1;
    #[cfg(not(any(test, feature = "test-support")))]
    const INCOMING_BUFFER_SIZE: usize = 64;
    let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
    let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

    let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
    let connection_state = ConnectionState {
        outgoing_tx: outgoing_tx.clone(),
        next_message_id: Default::default(),
        response_channels: Arc::new(Mutex::new(Some(Default::default()))),
    };
    let mut writer = MessageStream::new(connection.tx);
    let mut reader = MessageStream::new(connection.rx);

    let this = self.clone();
    let response_channels = connection_state.response_channels.clone();
    let handle_io = async move {
        tracing::debug!(%connection_id, "handle io future: start");

        // Cleanup guard (presumably `util::defer` runs the closure when the guard is
        // dropped — confirm in `util`): drops all pending response senders, which makes
        // waiting requests fail, and unregisters the connection.
        let _end_connection = util::defer(|| {
            response_channels.lock().take();
            this.connections.write().remove(&connection_id);
            tracing::debug!(%connection_id, "handle io future: end");
        });

        // Send messages on this frequency so the connection isn't closed.
        let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
        futures::pin_mut!(keepalive_timer);

        // Disconnect if we don't receive messages at least this frequently.
        let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
        futures::pin_mut!(receive_timeout);

        // Outer loop: one iteration per message read from the socket. The read future is
        // created once here and kept alive across inner-loop iterations, so handling an
        // outgoing message or a keepalive does not restart an in-progress read.
        loop {
            tracing::debug!(%connection_id, "outer loop iteration start");
            let read_message = reader.read().fuse();
            futures::pin_mut!(read_message);

            loop {
                tracing::debug!(%connection_id, "inner loop iteration start");
                // `select_biased!` polls its branches in the order written, so queued
                // outgoing messages take priority over reads and timers.
                futures::select_biased! {
                    outgoing = outgoing_rx.next().fuse() => match outgoing {
                        Some(outgoing) => {
                            tracing::debug!(%connection_id, "outgoing rpc message: writing");
                            if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
                                tracing::debug!(%connection_id, "outgoing rpc message: done writing");
                                result.context("failed to write RPC message")?;
                                // Any successful write counts as a keepalive.
                                tracing::debug!(%connection_id, "keepalive interval: resetting after sending message");
                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                            } else {
                                tracing::debug!(%connection_id, "outgoing rpc message: writing timed out");
                                Err(anyhow!("timed out writing message"))?;
                            }
                        }
                        // Every sender for the outgoing channel is gone: finish cleanly.
                        None => {
                            tracing::debug!(%connection_id, "outgoing rpc message: channel closed");
                            return Ok(())
                        },
                    },
                    incoming = read_message => {
                        let incoming = incoming.context("error reading rpc message from socket")?;
                        tracing::debug!(%connection_id, "incoming rpc message: received");
                        tracing::debug!(%connection_id, "receive timeout: resetting");
                        receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                        // Only envelopes are forwarded; other message kinds (such as pings)
                        // merely reset the receive timeout above.
                        if let proto::Message::Envelope(incoming) = incoming {
                            tracing::debug!(%connection_id, "incoming rpc message: processing");
                            match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
                                Some(Ok(_)) => {
                                    tracing::debug!(%connection_id, "incoming rpc message: processed");
                                },
                                Some(Err(_)) => {
                                    tracing::debug!(%connection_id, "incoming rpc message: channel closed");
                                    return Ok(())
                                },
                                None => {
                                    tracing::debug!(%connection_id, "incoming rpc message: processing timed out");
                                    Err(anyhow!("timed out processing incoming message"))?
                                },
                            }
                        }
                        // The read future has completed; leave the inner loop so the outer
                        // loop creates a fresh one.
                        break;
                    },
                    _ = keepalive_timer => {
                        tracing::debug!(%connection_id, "keepalive interval: pinging");
                        if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
                            tracing::debug!(%connection_id, "keepalive interval: done pinging");
                            result.context("failed to send keepalive")?;
                            tracing::debug!(%connection_id, "keepalive interval: resetting after pinging");
                            keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                        } else {
                            tracing::debug!(%connection_id, "keepalive interval: pinging timed out");
                            Err(anyhow!("timed out sending keepalive"))?;
                        }
                    }
                    _ = receive_timeout => {
                        tracing::debug!(%connection_id, "receive timeout: delay between messages too long");
                        Err(anyhow!("delay between messages too long"))?
                    }
                }
            }
        }
    };

    let response_channels = connection_state.response_channels.clone();
    self.connections
        .write()
        .insert(connection_id, connection_state);

    // Split incoming envelopes: responses are routed to the request that is waiting for
    // them and filtered out of the stream; everything else is yielded as a typed envelope.
    let incoming_rx = incoming_rx.filter_map(move |incoming| {
        let response_channels = response_channels.clone();
        async move {
            let message_id = incoming.id;
            tracing::debug!(?incoming, "incoming message future: start");
            let _end = util::defer(move || {
                tracing::debug!(
                    %connection_id,
                    message_id,
                    "incoming message future: end"
                );
            });

            if let Some(responding_to) = incoming.responding_to {
                tracing::debug!(
                    %connection_id,
                    message_id,
                    responding_to,
                    "incoming response: received"
                );
                // If the response map has already been taken (set to `None`), the `?`
                // yields `None` and the response is dropped.
                let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                if let Some(tx) = channel {
                    let requester_resumed = oneshot::channel();
                    if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to = responding_to,
                            ?error,
                            "incoming response: request future dropped",
                        );
                    }

                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: waiting to resume requester"
                    );
                    // Wait until the requester drops its end of the barrier (or the barrier
                    // sender is dropped because the requester is gone). This keeps later
                    // incoming messages from being yielded before the requester has
                    // observed its response.
                    let _ = requester_resumed.1.await;
                    tracing::debug!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: requester resumed"
                    );
                } else {
                    tracing::warn!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: unknown request"
                    );
                }

                None
            } else {
                tracing::debug!(
                    %connection_id,
                    message_id,
                    "incoming message: received"
                );
                proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                    tracing::error!(
                        %connection_id,
                        message_id,
                        "unable to construct a typed envelope"
                    );
                    None
                })
            }
        }
    });
    (connection_id, handle_io, incoming_rx.boxed())
}
319
/// Test-support wrapper around `add_connection` that creates its keepalive and
/// receive-timeout timers from `executor`.
///
/// Returns the same tuple as `add_connection`: the connection id, the I/O future, and the
/// stream of incoming messages.
#[cfg(any(test, feature = "test-support"))]
pub async fn add_test_connection(
    self: &Arc<Self>,
    connection: Connection,
    executor: Arc<gpui::executor::Background>,
) -> (
    ConnectionId,
    impl Future<Output = anyhow::Result<()>> + Send,
    BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
) {
    // `executor` is already an owned `Arc`, so it moves into the closure without a clone.
    self.add_connection(connection, move |duration| executor.timer(duration))
        .await
}
334
335 pub fn disconnect(&self, connection_id: ConnectionId) {
336 self.connections.write().remove(&connection_id);
337 }
338
339 pub fn reset(&self) {
340 self.connections.write().clear();
341 }
342
343 pub fn request<T: RequestMessage>(
344 &self,
345 receiver_id: ConnectionId,
346 request: T,
347 ) -> impl Future<Output = Result<T::Response>> {
348 self.request_internal(None, receiver_id, request)
349 }
350
351 pub fn forward_request<T: RequestMessage>(
352 &self,
353 sender_id: ConnectionId,
354 receiver_id: ConnectionId,
355 request: T,
356 ) -> impl Future<Output = Result<T::Response>> {
357 self.request_internal(Some(sender_id), receiver_id, request)
358 }
359
360 pub fn request_internal<T: RequestMessage>(
361 &self,
362 original_sender_id: Option<ConnectionId>,
363 receiver_id: ConnectionId,
364 request: T,
365 ) -> impl Future<Output = Result<T::Response>> {
366 let (tx, rx) = oneshot::channel();
367 let send = self.connection_state(receiver_id).and_then(|connection| {
368 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
369 connection
370 .response_channels
371 .lock()
372 .as_mut()
373 .ok_or_else(|| anyhow!("connection was closed"))?
374 .insert(message_id, tx);
375 connection
376 .outgoing_tx
377 .unbounded_send(proto::Message::Envelope(request.into_envelope(
378 message_id,
379 None,
380 original_sender_id.map(|id| id.0),
381 )))
382 .map_err(|_| anyhow!("connection was closed"))?;
383 Ok(())
384 });
385 async move {
386 send?;
387 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
388 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
389 Err(anyhow!("RPC request failed - {}", error.message))
390 } else {
391 T::Response::from_envelope(response)
392 .ok_or_else(|| anyhow!("received response of the wrong type"))
393 }
394 }
395 }
396
397 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
398 let connection = self.connection_state(receiver_id)?;
399 let message_id = connection
400 .next_message_id
401 .fetch_add(1, atomic::Ordering::SeqCst);
402 connection
403 .outgoing_tx
404 .unbounded_send(proto::Message::Envelope(
405 message.into_envelope(message_id, None, None),
406 ))?;
407 Ok(())
408 }
409
410 pub fn forward_send<T: EnvelopedMessage>(
411 &self,
412 sender_id: ConnectionId,
413 receiver_id: ConnectionId,
414 message: T,
415 ) -> Result<()> {
416 let connection = self.connection_state(receiver_id)?;
417 let message_id = connection
418 .next_message_id
419 .fetch_add(1, atomic::Ordering::SeqCst);
420 connection
421 .outgoing_tx
422 .unbounded_send(proto::Message::Envelope(message.into_envelope(
423 message_id,
424 None,
425 Some(sender_id.0),
426 )))?;
427 Ok(())
428 }
429
430 pub fn respond<T: RequestMessage>(
431 &self,
432 receipt: Receipt<T>,
433 response: T::Response,
434 ) -> Result<()> {
435 let connection = self.connection_state(receipt.sender_id)?;
436 let message_id = connection
437 .next_message_id
438 .fetch_add(1, atomic::Ordering::SeqCst);
439 connection
440 .outgoing_tx
441 .unbounded_send(proto::Message::Envelope(response.into_envelope(
442 message_id,
443 Some(receipt.message_id),
444 None,
445 )))?;
446 Ok(())
447 }
448
449 pub fn respond_with_error<T: RequestMessage>(
450 &self,
451 receipt: Receipt<T>,
452 response: proto::Error,
453 ) -> Result<()> {
454 let connection = self.connection_state(receipt.sender_id)?;
455 let message_id = connection
456 .next_message_id
457 .fetch_add(1, atomic::Ordering::SeqCst);
458 connection
459 .outgoing_tx
460 .unbounded_send(proto::Message::Envelope(response.into_envelope(
461 message_id,
462 Some(receipt.message_id),
463 None,
464 )))?;
465 Ok(())
466 }
467
468 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
469 let connections = self.connections.read();
470 let connection = connections
471 .get(&connection_id)
472 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
473 Ok(connection.clone())
474 }
475}
476
477impl Serialize for Peer {
478 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
479 where
480 S: serde::Serializer,
481 {
482 let mut state = serializer.serialize_struct("Peer", 2)?;
483 state.serialize_field("connections", &*self.connections.read())?;
484 state.end()
485 }
486}
487
488#[cfg(test)]
489mod tests {
490 use super::*;
491 use crate::TypedEnvelope;
492 use async_tungstenite::tungstenite::Message as WebSocketMessage;
493 use gpui::TestAppContext;
494
495 #[gpui::test(iterations = 50)]
496 async fn test_request_response(cx: &mut TestAppContext) {
497 let executor = cx.foreground();
498
499 // create 2 clients connected to 1 server
500 let server = Peer::new();
501 let client1 = Peer::new();
502 let client2 = Peer::new();
503
504 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
505 Connection::in_memory(cx.background());
506 let (client1_conn_id, io_task1, client1_incoming) = client1
507 .add_test_connection(client1_to_server_conn, cx.background())
508 .await;
509 let (_, io_task2, server_incoming1) = server
510 .add_test_connection(server_to_client_1_conn, cx.background())
511 .await;
512
513 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
514 Connection::in_memory(cx.background());
515 let (client2_conn_id, io_task3, client2_incoming) = client2
516 .add_test_connection(client2_to_server_conn, cx.background())
517 .await;
518 let (_, io_task4, server_incoming2) = server
519 .add_test_connection(server_to_client_2_conn, cx.background())
520 .await;
521
522 executor.spawn(io_task1).detach();
523 executor.spawn(io_task2).detach();
524 executor.spawn(io_task3).detach();
525 executor.spawn(io_task4).detach();
526 executor
527 .spawn(handle_messages(server_incoming1, server.clone()))
528 .detach();
529 executor
530 .spawn(handle_messages(client1_incoming, client1.clone()))
531 .detach();
532 executor
533 .spawn(handle_messages(server_incoming2, server.clone()))
534 .detach();
535 executor
536 .spawn(handle_messages(client2_incoming, client2.clone()))
537 .detach();
538
539 assert_eq!(
540 client1
541 .request(client1_conn_id, proto::Ping {},)
542 .await
543 .unwrap(),
544 proto::Ack {}
545 );
546
547 assert_eq!(
548 client2
549 .request(client2_conn_id, proto::Ping {},)
550 .await
551 .unwrap(),
552 proto::Ack {}
553 );
554
555 assert_eq!(
556 client1
557 .request(client1_conn_id, proto::Test { id: 1 },)
558 .await
559 .unwrap(),
560 proto::Test { id: 1 }
561 );
562
563 assert_eq!(
564 client2
565 .request(client2_conn_id, proto::Test { id: 2 })
566 .await
567 .unwrap(),
568 proto::Test { id: 2 }
569 );
570
571 client1.disconnect(client1_conn_id);
572 client2.disconnect(client1_conn_id);
573
574 async fn handle_messages(
575 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
576 peer: Arc<Peer>,
577 ) -> Result<()> {
578 while let Some(envelope) = messages.next().await {
579 let envelope = envelope.into_any();
580 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
581 let receipt = envelope.receipt();
582 peer.respond(receipt, proto::Ack {})?
583 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
584 {
585 peer.respond(envelope.receipt(), envelope.payload.clone())?
586 } else {
587 panic!("unknown message type");
588 }
589 }
590
591 Ok(())
592 }
593 }
594
/// Messages that the server sends before its response must be observed by the client
/// before the client's request future resolves.
#[gpui::test(iterations = 50)]
async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
    let executor = cx.foreground();
    let server = Peer::new();
    let client = Peer::new();

    let (client_to_server_conn, server_to_client_conn, _kill) =
        Connection::in_memory(cx.background());
    let (client_to_server_conn_id, io_task1, mut client_incoming) = client
        .add_test_connection(client_to_server_conn, cx.background())
        .await;
    let (server_to_client_conn_id, io_task2, mut server_incoming) = server
        .add_test_connection(server_to_client_conn, cx.background())
        .await;

    executor.spawn(io_task1).detach();
    executor.spawn(io_task2).detach();

    // Server side: wait for the ping, send two unrelated messages, then respond.
    executor
        .spawn(async move {
            let request = server_incoming
                .next()
                .await
                .unwrap()
                .into_any()
                .downcast::<TypedEnvelope<proto::Ping>>()
                .unwrap();

            server
                .send(
                    server_to_client_conn_id,
                    proto::Error {
                        message: "message 1".to_string(),
                    },
                )
                .unwrap();
            server
                .send(
                    server_to_client_conn_id,
                    proto::Error {
                        message: "message 2".to_string(),
                    },
                )
                .unwrap();
            server.respond(request.receipt(), proto::Ack {}).unwrap();

            // Prevent the connection from being dropped
            server_incoming.next().await;
        })
        .detach();

    // Records the order in which the client observes messages and the response.
    let events = Arc::new(Mutex::new(Vec::new()));

    let response = client.request(client_to_server_conn_id, proto::Ping {});
    let response_task = executor.spawn({
        let events = events.clone();
        async move {
            response.await.unwrap();
            events.lock().push("response".to_string());
        }
    });

    // Client side: record the two plain messages as they come off the incoming stream.
    executor
        .spawn({
            let events = events.clone();
            async move {
                let incoming1 = client_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Error>>()
                    .unwrap();
                events.lock().push(incoming1.payload.message);
                let incoming2 = client_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Error>>()
                    .unwrap();
                events.lock().push(incoming2.payload.message);

                // Prevent the connection from being dropped
                client_incoming.next().await;
            }
        })
        .detach();

    response_task.await;
    assert_eq!(
        &*events.lock(),
        &[
            "message 1".to_string(),
            "message 2".to_string(),
            "response".to_string()
        ]
    );
}
694
/// Dropping one in-flight request must not stall the connection: a second request still
/// receives its response, and earlier incoming messages are still observed first.
#[gpui::test(iterations = 50)]
async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
    let executor = cx.foreground();
    let server = Peer::new();
    let client = Peer::new();

    let (client_to_server_conn, server_to_client_conn, _kill) =
        Connection::in_memory(cx.background());
    let (client_to_server_conn_id, io_task1, mut client_incoming) = client
        .add_test_connection(client_to_server_conn, cx.background())
        .await;
    let (server_to_client_conn_id, io_task2, mut server_incoming) = server
        .add_test_connection(server_to_client_conn, cx.background())
        .await;

    executor.spawn(io_task1).detach();
    executor.spawn(io_task2).detach();

    // Server side: wait for both pings, send two unrelated messages, then answer both.
    executor
        .spawn(async move {
            let request1 = server_incoming
                .next()
                .await
                .unwrap()
                .into_any()
                .downcast::<TypedEnvelope<proto::Ping>>()
                .unwrap();
            let request2 = server_incoming
                .next()
                .await
                .unwrap()
                .into_any()
                .downcast::<TypedEnvelope<proto::Ping>>()
                .unwrap();

            server
                .send(
                    server_to_client_conn_id,
                    proto::Error {
                        message: "message 1".to_string(),
                    },
                )
                .unwrap();
            server
                .send(
                    server_to_client_conn_id,
                    proto::Error {
                        message: "message 2".to_string(),
                    },
                )
                .unwrap();
            server.respond(request1.receipt(), proto::Ack {}).unwrap();
            server.respond(request2.receipt(), proto::Ack {}).unwrap();

            // Prevent the connection from being dropped
            server_incoming.next().await;
        })
        .detach();

    // Records the order in which the client observes messages and the second response.
    let events = Arc::new(Mutex::new(Vec::new()));

    let request1 = client.request(client_to_server_conn_id, proto::Ping {});
    let request1_task = executor.spawn(request1);
    let request2 = client.request(client_to_server_conn_id, proto::Ping {});
    let request2_task = executor.spawn({
        let events = events.clone();
        async move {
            request2.await.unwrap();
            events.lock().push("response 2".to_string());
        }
    });

    // Client side: record the two plain messages as they come off the incoming stream.
    executor
        .spawn({
            let events = events.clone();
            async move {
                let incoming1 = client_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Error>>()
                    .unwrap();
                events.lock().push(incoming1.payload.message);
                let incoming2 = client_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Error>>()
                    .unwrap();
                events.lock().push(incoming2.payload.message);

                // Prevent the connection from being dropped
                client_incoming.next().await;
            }
        })
        .detach();

    // Allow the request to make some progress before dropping it.
    cx.background().simulate_random_delay().await;
    drop(request1_task);

    request2_task.await;
    assert_eq!(
        &*events.lock(),
        &[
            "message 1".to_string(),
            "message 2".to_string(),
            "response 2".to_string()
        ]
    );
}
808
/// After `disconnect`, the I/O future and the incoming stream both finish, and the other
/// end of the connection can no longer send.
#[gpui::test(iterations = 50)]
async fn test_disconnect(cx: &mut TestAppContext) {
    let executor = cx.foreground();

    let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

    let client = Peer::new();
    let (connection_id, io_handler, mut incoming) = client
        .add_test_connection(client_conn, cx.background())
        .await;

    // Signals when the I/O future has finished (successfully or not).
    let (io_ended_tx, io_ended_rx) = oneshot::channel();
    executor
        .spawn(async move {
            io_handler.await.ok();
            io_ended_tx.send(()).unwrap();
        })
        .detach();

    // Signals when the incoming stream has yielded once (an item or its end).
    let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
    executor
        .spawn(async move {
            incoming.next().await;
            messages_ended_tx.send(()).unwrap();
        })
        .detach();

    client.disconnect(connection_id);

    let _ = io_ended_rx.await;
    let _ = messages_ended_rx.await;
    assert!(server_conn
        .send(WebSocketMessage::Binary(vec![]))
        .await
        .is_err());
}
845
/// When the other end of the connection is dropped while a request is pending, the
/// request fails with "connection was closed".
#[gpui::test(iterations = 50)]
async fn test_io_error(cx: &mut TestAppContext) {
    let executor = cx.foreground();
    let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

    let client = Peer::new();
    let (connection_id, io_handler, mut incoming) = client
        .add_test_connection(client_conn, cx.background())
        .await;
    executor.spawn(io_handler).detach();
    executor
        .spawn(async move { incoming.next().await })
        .detach();

    let response = executor.spawn(client.request(connection_id, proto::Ping {}));
    // Make sure the request has actually reached the server side before tearing it down.
    let _request = server_conn.rx.next().await.unwrap().unwrap();

    drop(server_conn);
    assert_eq!(
        response.await.unwrap_err().to_string(),
        "connection was closed"
    );
}
869}