1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
2
3use super::{
4 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
5 Connection,
6};
7use anyhow::{anyhow, Context, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::BoxStream,
12 FutureExt, SinkExt, StreamExt, TryFutureExt,
13};
14use parking_lot::{Mutex, RwLock};
15use serde::{ser::SerializeStruct, Serialize};
16use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
17use std::{
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24 time::Duration,
25};
26use tracing::instrument;
27
28#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
29pub struct ConnectionId {
30 pub owner_id: u32,
31 pub id: u32,
32}
33
34impl Into<PeerId> for ConnectionId {
35 fn into(self) -> PeerId {
36 PeerId {
37 owner_id: self.owner_id,
38 id: self.id,
39 }
40 }
41}
42
43impl From<PeerId> for ConnectionId {
44 fn from(peer_id: PeerId) -> Self {
45 Self {
46 owner_id: peer_id.owner_id,
47 id: peer_id.id,
48 }
49 }
50}
51
52impl fmt::Display for ConnectionId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}/{}", self.owner_id, self.id)
55 }
56}
57
58pub struct Receipt<T> {
59 pub sender_id: ConnectionId,
60 pub message_id: u32,
61 payload_type: PhantomData<T>,
62}
63
64impl<T> Clone for Receipt<T> {
65 fn clone(&self) -> Self {
66 Self {
67 sender_id: self.sender_id,
68 message_id: self.message_id,
69 payload_type: PhantomData,
70 }
71 }
72}
73
74impl<T> Copy for Receipt<T> {}
75
76#[derive(Clone, Debug)]
77pub struct TypedEnvelope<T> {
78 pub sender_id: ConnectionId,
79 pub original_sender_id: Option<PeerId>,
80 pub message_id: u32,
81 pub payload: T,
82 pub received_at: Instant,
83}
84
85impl<T> TypedEnvelope<T> {
86 pub fn original_sender_id(&self) -> Result<PeerId> {
87 self.original_sender_id
88 .ok_or_else(|| anyhow!("missing original_sender_id"))
89 }
90}
91
92impl<T: RequestMessage> TypedEnvelope<T> {
93 pub fn receipt(&self) -> Receipt<T> {
94 Receipt {
95 sender_id: self.sender_id,
96 message_id: self.message_id,
97 payload_type: PhantomData,
98 }
99 }
100}
101
102pub struct Peer {
103 epoch: AtomicU32,
104 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
105 next_connection_id: AtomicU32,
106}
107
108#[derive(Clone, Serialize)]
109pub struct ConnectionState {
110 #[serde(skip)]
111 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
112 next_message_id: Arc<AtomicU32>,
113 #[allow(clippy::type_complexity)]
114 #[serde(skip)]
115 response_channels: Arc<
116 Mutex<
117 Option<
118 HashMap<
119 u32,
120 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
121 >,
122 >,
123 >,
124 >,
125}
126
/// After this long without an outgoing write, the I/O loop sends a `Ping`.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on a single socket write, and on handing one incoming message to the
/// bounded incoming channel.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection fails if no message is read for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
130
131impl Peer {
132 pub fn new(epoch: u32) -> Arc<Self> {
133 Arc::new(Self {
134 epoch: AtomicU32::new(epoch),
135 connections: Default::default(),
136 next_connection_id: Default::default(),
137 })
138 }
139
140 pub fn epoch(&self) -> u32 {
141 self.epoch.load(SeqCst)
142 }
143
144 #[instrument(skip_all)]
145 pub fn add_connection<F, Fut, Out>(
146 self: &Arc<Self>,
147 connection: Connection,
148 create_timer: F,
149 ) -> (
150 ConnectionId,
151 impl Future<Output = anyhow::Result<()>> + Send,
152 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
153 )
154 where
155 F: Send + Fn(Duration) -> Fut,
156 Fut: Send + Future<Output = Out>,
157 Out: Send,
158 {
159 // For outgoing messages, use an unbounded channel so that application code
160 // can always send messages without yielding. For incoming messages, use a
161 // bounded channel so that other peers will receive backpressure if they send
162 // messages faster than this peer can process them.
163 #[cfg(any(test, feature = "test-support"))]
164 const INCOMING_BUFFER_SIZE: usize = 1;
165 #[cfg(not(any(test, feature = "test-support")))]
166 const INCOMING_BUFFER_SIZE: usize = 256;
167 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
168 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
169
170 let connection_id = ConnectionId {
171 owner_id: self.epoch.load(SeqCst),
172 id: self.next_connection_id.fetch_add(1, SeqCst),
173 };
174 let connection_state = ConnectionState {
175 outgoing_tx,
176 next_message_id: Default::default(),
177 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
178 };
179 let mut writer = MessageStream::new(connection.tx);
180 let mut reader = MessageStream::new(connection.rx);
181
182 let this = self.clone();
183 let response_channels = connection_state.response_channels.clone();
184 let handle_io = async move {
185 tracing::trace!(%connection_id, "handle io future: start");
186
187 let _end_connection = util::defer(|| {
188 response_channels.lock().take();
189 this.connections.write().remove(&connection_id);
190 tracing::trace!(%connection_id, "handle io future: end");
191 });
192
193 // Send messages on this frequency so the connection isn't closed.
194 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
195 futures::pin_mut!(keepalive_timer);
196
197 // Disconnect if we don't receive messages at least this frequently.
198 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
199 futures::pin_mut!(receive_timeout);
200
201 loop {
202 tracing::trace!(%connection_id, "outer loop iteration start");
203 let read_message = reader.read().fuse();
204 futures::pin_mut!(read_message);
205
206 loop {
207 tracing::trace!(%connection_id, "inner loop iteration start");
208 futures::select_biased! {
209 outgoing = outgoing_rx.next().fuse() => match outgoing {
210 Some(outgoing) => {
211 tracing::trace!(%connection_id, "outgoing rpc message: writing");
212 futures::select_biased! {
213 result = writer.write(outgoing).fuse() => {
214 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
215 result.context("failed to write RPC message")?;
216 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
217 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
218 }
219 _ = create_timer(WRITE_TIMEOUT).fuse() => {
220 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
221 Err(anyhow!("timed out writing message"))?;
222 }
223 }
224 }
225 None => {
226 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
227 return Ok(())
228 },
229 },
230 _ = keepalive_timer => {
231 tracing::trace!(%connection_id, "keepalive interval: pinging");
232 futures::select_biased! {
233 result = writer.write(proto::Message::Ping).fuse() => {
234 tracing::trace!(%connection_id, "keepalive interval: done pinging");
235 result.context("failed to send keepalive")?;
236 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
237 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
238 }
239 _ = create_timer(WRITE_TIMEOUT).fuse() => {
240 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
241 Err(anyhow!("timed out sending keepalive"))?;
242 }
243 }
244 }
245 incoming = read_message => {
246 let incoming = incoming.context("error reading rpc message from socket")?;
247 tracing::trace!(%connection_id, "incoming rpc message: received");
248 tracing::trace!(%connection_id, "receive timeout: resetting");
249 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
250 if let (proto::Message::Envelope(incoming), received_at) = incoming {
251 tracing::trace!(%connection_id, "incoming rpc message: processing");
252 futures::select_biased! {
253 result = incoming_tx.send((incoming, received_at)).fuse() => match result {
254 Ok(_) => {
255 tracing::trace!(%connection_id, "incoming rpc message: processed");
256 }
257 Err(_) => {
258 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
259 return Ok(())
260 }
261 },
262 _ = create_timer(WRITE_TIMEOUT).fuse() => {
263 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
264 Err(anyhow!("timed out processing incoming message"))?
265 }
266 }
267 }
268 break;
269 },
270 _ = receive_timeout => {
271 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
272 Err(anyhow!("delay between messages too long"))?
273 }
274 }
275 }
276 }
277 };
278
279 let response_channels = connection_state.response_channels.clone();
280 self.connections
281 .write()
282 .insert(connection_id, connection_state);
283
284 let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
285 let response_channels = response_channels.clone();
286 async move {
287 let message_id = incoming.id;
288 tracing::trace!(?incoming, "incoming message future: start");
289 let _end = util::defer(move || {
290 tracing::trace!(%connection_id, message_id, "incoming message future: end");
291 });
292
293 if let Some(responding_to) = incoming.responding_to {
294 tracing::trace!(
295 %connection_id,
296 message_id,
297 responding_to,
298 "incoming response: received"
299 );
300 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
301 if let Some(tx) = channel {
302 let requester_resumed = oneshot::channel();
303 if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
304 tracing::trace!(
305 %connection_id,
306 message_id,
307 responding_to = responding_to,
308 ?error,
309 "incoming response: request future dropped",
310 );
311 }
312
313 tracing::trace!(
314 %connection_id,
315 message_id,
316 responding_to,
317 "incoming response: waiting to resume requester"
318 );
319 let _ = requester_resumed.1.await;
320 tracing::trace!(
321 %connection_id,
322 message_id,
323 responding_to,
324 "incoming response: requester resumed"
325 );
326 } else {
327 let message_type =
328 proto::build_typed_envelope(connection_id, received_at, incoming)
329 .map(|p| p.payload_type_name());
330 tracing::warn!(
331 %connection_id,
332 message_id,
333 responding_to,
334 message_type,
335 "incoming response: unknown request"
336 );
337 }
338
339 None
340 } else {
341 tracing::trace!(%connection_id, message_id, "incoming message: received");
342 proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
343 || {
344 tracing::error!(
345 %connection_id,
346 message_id,
347 "unable to construct a typed envelope"
348 );
349 None
350 },
351 )
352 }
353 }
354 });
355 (connection_id, handle_io, incoming_rx.boxed())
356 }
357
358 #[cfg(any(test, feature = "test-support"))]
359 pub fn add_test_connection(
360 self: &Arc<Self>,
361 connection: Connection,
362 executor: gpui::BackgroundExecutor,
363 ) -> (
364 ConnectionId,
365 impl Future<Output = anyhow::Result<()>> + Send,
366 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
367 ) {
368 let executor = executor.clone();
369 self.add_connection(connection, move |duration| executor.timer(duration))
370 }
371
372 pub fn disconnect(&self, connection_id: ConnectionId) {
373 self.connections.write().remove(&connection_id);
374 }
375
376 #[cfg(any(test, feature = "test-support"))]
377 pub fn reset(&self, epoch: u32) {
378 self.next_connection_id.store(0, SeqCst);
379 self.epoch.store(epoch, SeqCst);
380 }
381
382 pub fn teardown(&self) {
383 self.connections.write().clear();
384 }
385
386 pub fn request<T: RequestMessage>(
387 &self,
388 receiver_id: ConnectionId,
389 request: T,
390 ) -> impl Future<Output = Result<T::Response>> {
391 self.request_internal(None, receiver_id, request)
392 .map_ok(|envelope| envelope.payload)
393 }
394
395 pub fn request_envelope<T: RequestMessage>(
396 &self,
397 receiver_id: ConnectionId,
398 request: T,
399 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
400 self.request_internal(None, receiver_id, request)
401 }
402
403 pub fn forward_request<T: RequestMessage>(
404 &self,
405 sender_id: ConnectionId,
406 receiver_id: ConnectionId,
407 request: T,
408 ) -> impl Future<Output = Result<T::Response>> {
409 self.request_internal(Some(sender_id), receiver_id, request)
410 .map_ok(|envelope| envelope.payload)
411 }
412
413 pub fn request_internal<T: RequestMessage>(
414 &self,
415 original_sender_id: Option<ConnectionId>,
416 receiver_id: ConnectionId,
417 request: T,
418 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
419 let (tx, rx) = oneshot::channel();
420 let send = self.connection_state(receiver_id).and_then(|connection| {
421 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
422 connection
423 .response_channels
424 .lock()
425 .as_mut()
426 .ok_or_else(|| anyhow!("connection was closed"))?
427 .insert(message_id, tx);
428 connection
429 .outgoing_tx
430 .unbounded_send(proto::Message::Envelope(request.into_envelope(
431 message_id,
432 None,
433 original_sender_id.map(Into::into),
434 )))
435 .map_err(|_| anyhow!("connection was closed"))?;
436 Ok(())
437 });
438 async move {
439 send?;
440 let (response, received_at, _barrier) =
441 rx.await.map_err(|_| anyhow!("connection was closed"))?;
442
443 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
444 Err(RpcError::from_proto(&error, T::NAME))
445 } else {
446 Ok(TypedEnvelope {
447 message_id: response.id,
448 sender_id: receiver_id,
449 original_sender_id: response.original_sender_id,
450 payload: T::Response::from_envelope(response)
451 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
452 received_at,
453 })
454 }
455 }
456 }
457
458 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
459 let connection = self.connection_state(receiver_id)?;
460 let message_id = connection
461 .next_message_id
462 .fetch_add(1, atomic::Ordering::SeqCst);
463 connection
464 .outgoing_tx
465 .unbounded_send(proto::Message::Envelope(
466 message.into_envelope(message_id, None, None),
467 ))?;
468 Ok(())
469 }
470
471 pub fn forward_send<T: EnvelopedMessage>(
472 &self,
473 sender_id: ConnectionId,
474 receiver_id: ConnectionId,
475 message: T,
476 ) -> Result<()> {
477 let connection = self.connection_state(receiver_id)?;
478 let message_id = connection
479 .next_message_id
480 .fetch_add(1, atomic::Ordering::SeqCst);
481 connection
482 .outgoing_tx
483 .unbounded_send(proto::Message::Envelope(message.into_envelope(
484 message_id,
485 None,
486 Some(sender_id.into()),
487 )))?;
488 Ok(())
489 }
490
491 pub fn respond<T: RequestMessage>(
492 &self,
493 receipt: Receipt<T>,
494 response: T::Response,
495 ) -> Result<()> {
496 let connection = self.connection_state(receipt.sender_id)?;
497 let message_id = connection
498 .next_message_id
499 .fetch_add(1, atomic::Ordering::SeqCst);
500 connection
501 .outgoing_tx
502 .unbounded_send(proto::Message::Envelope(response.into_envelope(
503 message_id,
504 Some(receipt.message_id),
505 None,
506 )))?;
507 Ok(())
508 }
509
510 pub fn respond_with_error<T: RequestMessage>(
511 &self,
512 receipt: Receipt<T>,
513 response: proto::Error,
514 ) -> Result<()> {
515 let connection = self.connection_state(receipt.sender_id)?;
516 let message_id = connection
517 .next_message_id
518 .fetch_add(1, atomic::Ordering::SeqCst);
519 connection
520 .outgoing_tx
521 .unbounded_send(proto::Message::Envelope(response.into_envelope(
522 message_id,
523 Some(receipt.message_id),
524 None,
525 )))?;
526 Ok(())
527 }
528
529 pub fn respond_with_unhandled_message(
530 &self,
531 envelope: Box<dyn AnyTypedEnvelope>,
532 ) -> Result<()> {
533 let connection = self.connection_state(envelope.sender_id())?;
534 let response = ErrorCode::Internal
535 .message(format!(
536 "message {} was not handled",
537 envelope.payload_type_name()
538 ))
539 .to_proto();
540 let message_id = connection
541 .next_message_id
542 .fetch_add(1, atomic::Ordering::SeqCst);
543 connection
544 .outgoing_tx
545 .unbounded_send(proto::Message::Envelope(response.into_envelope(
546 message_id,
547 Some(envelope.message_id()),
548 None,
549 )))?;
550 Ok(())
551 }
552
553 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
554 let connections = self.connections.read();
555 let connection = connections
556 .get(&connection_id)
557 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
558 Ok(connection.clone())
559 }
560}
561
562impl Serialize for Peer {
563 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
564 where
565 S: serde::Serializer,
566 {
567 let mut state = serializer.serialize_struct("Peer", 2)?;
568 state.serialize_field("connections", &*self.connections.read())?;
569 state.end()
570 }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Enables `env_logger` when `RUST_LOG` is set.
    ///
    /// `try_init` is used because `gpui::test(iterations = ..)` calls this repeatedly in one
    /// process, and `env_logger::init` panics once a logger is already installed.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            let _ = env_logger::try_init();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client disconnects its own connection. This previously passed
        // `client1_conn_id`, which only worked because both ids happen to be equal.
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}