1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
2
3use super::{
4 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
5 Connection,
6};
7use anyhow::{anyhow, Context, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::BoxStream,
12 FutureExt, SinkExt, StreamExt, TryFutureExt,
13};
14use parking_lot::{Mutex, RwLock};
15use serde::{ser::SerializeStruct, Serialize};
16use std::{fmt, sync::atomic::Ordering::SeqCst, time::Instant};
17use std::{
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24 time::Duration,
25};
26use tracing::instrument;
27
/// Identifies one connection managed by a [`Peer`].
///
/// `owner_id` is the peer's epoch at the moment the connection was added and `id`
/// comes from a per-peer counter (see `Peer::add_connection`).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    pub owner_id: u32,
    pub id: u32,
}
33
34impl Into<PeerId> for ConnectionId {
35 fn into(self) -> PeerId {
36 PeerId {
37 owner_id: self.owner_id,
38 id: self.id,
39 }
40 }
41}
42
43impl From<PeerId> for ConnectionId {
44 fn from(peer_id: PeerId) -> Self {
45 Self {
46 owner_id: peer_id.owner_id,
47 id: peer_id.id,
48 }
49 }
50}
51
52impl fmt::Display for ConnectionId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}/{}", self.owner_id, self.id)
55 }
56}
57
/// A handle for answering a request that arrived as a [`TypedEnvelope`].
///
/// Produced by `TypedEnvelope::receipt` and consumed by `Peer::respond` /
/// `Peer::respond_with_error`, which use `sender_id` to select the connection and
/// `message_id` as the `responding_to` id of the reply.
pub struct Receipt<T> {
    /// Connection the request arrived on.
    pub sender_id: ConnectionId,
    /// Id of the request message being answered.
    pub message_id: u32,
    /// Ties the receipt to the request type `T`, so `Peer::respond` only accepts
    /// a `T::Response` for it.
    payload_type: PhantomData<T>,
}

// `Clone` and `Copy` are written by hand because `#[derive]` would add
// `T: Clone` / `T: Copy` bounds, even though `T` is only held as `PhantomData`.
impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Receipt<T> {}
71
/// A decoded message payload of type `T` plus metadata about how it arrived.
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
    /// Connection the message arrived on.
    pub sender_id: ConnectionId,
    /// Original sender carried in the envelope, if any (set on messages sent via
    /// `forward_send` / `forward_request`).
    pub original_sender_id: Option<PeerId>,
    /// Id the sender assigned to this message; replies reference it.
    pub message_id: u32,
    /// The decoded message.
    pub payload: T,
    /// Timestamp returned by `MessageStream::read` alongside this message.
    pub received_at: Instant,
}
80
81impl<T> TypedEnvelope<T> {
82 pub fn original_sender_id(&self) -> Result<PeerId> {
83 self.original_sender_id
84 .ok_or_else(|| anyhow!("missing original_sender_id"))
85 }
86}
87
88impl<T: RequestMessage> TypedEnvelope<T> {
89 pub fn receipt(&self) -> Receipt<T> {
90 Receipt {
91 sender_id: self.sender_id,
92 message_id: self.message_id,
93 payload_type: PhantomData,
94 }
95 }
96}
97
/// Multiplexes typed requests, responses, and one-way messages over a set of
/// connections.
pub struct Peer {
    /// Copied into `ConnectionId::owner_id` for every connection added.
    epoch: AtomicU32,
    /// Live connections keyed by id. Entries are removed by `disconnect`,
    /// `teardown`, or when a connection's I/O future finishes.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Source of `ConnectionId::id` values.
    next_connection_id: AtomicU32,
}
103
104#[derive(Clone, Serialize)]
105pub struct ConnectionState {
106 #[serde(skip)]
107 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
108 next_message_id: Arc<AtomicU32>,
109 #[allow(clippy::type_complexity)]
110 #[serde(skip)]
111 response_channels: Arc<
112 Mutex<
113 Option<
114 HashMap<
115 u32,
116 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
117 >,
118 >,
119 >,
120 >,
121}
122
/// A `Ping` is written whenever this much time passes without a successful write
/// resetting the keepalive timer.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Longest a single socket write, or the hand-off of an incoming message to the
/// incoming channel, may take before the connection fails.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection fails if no message is read within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
126
127impl Peer {
128 pub fn new(epoch: u32) -> Arc<Self> {
129 Arc::new(Self {
130 epoch: AtomicU32::new(epoch),
131 connections: Default::default(),
132 next_connection_id: Default::default(),
133 })
134 }
135
136 pub fn epoch(&self) -> u32 {
137 self.epoch.load(SeqCst)
138 }
139
140 #[instrument(skip_all)]
141 pub fn add_connection<F, Fut, Out>(
142 self: &Arc<Self>,
143 connection: Connection,
144 create_timer: F,
145 ) -> (
146 ConnectionId,
147 impl Future<Output = anyhow::Result<()>> + Send,
148 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
149 )
150 where
151 F: Send + Fn(Duration) -> Fut,
152 Fut: Send + Future<Output = Out>,
153 Out: Send,
154 {
155 // For outgoing messages, use an unbounded channel so that application code
156 // can always send messages without yielding. For incoming messages, use a
157 // bounded channel so that other peers will receive backpressure if they send
158 // messages faster than this peer can process them.
159 #[cfg(any(test, feature = "test-support"))]
160 const INCOMING_BUFFER_SIZE: usize = 1;
161 #[cfg(not(any(test, feature = "test-support")))]
162 const INCOMING_BUFFER_SIZE: usize = 256;
163 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
164 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
165
166 let connection_id = ConnectionId {
167 owner_id: self.epoch.load(SeqCst),
168 id: self.next_connection_id.fetch_add(1, SeqCst),
169 };
170 let connection_state = ConnectionState {
171 outgoing_tx,
172 next_message_id: Default::default(),
173 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
174 };
175 let mut writer = MessageStream::new(connection.tx);
176 let mut reader = MessageStream::new(connection.rx);
177
178 let this = self.clone();
179 let response_channels = connection_state.response_channels.clone();
180 let handle_io = async move {
181 tracing::trace!(%connection_id, "handle io future: start");
182
183 let _end_connection = util::defer(|| {
184 response_channels.lock().take();
185 this.connections.write().remove(&connection_id);
186 tracing::trace!(%connection_id, "handle io future: end");
187 });
188
189 // Send messages on this frequency so the connection isn't closed.
190 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
191 futures::pin_mut!(keepalive_timer);
192
193 // Disconnect if we don't receive messages at least this frequently.
194 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
195 futures::pin_mut!(receive_timeout);
196
197 loop {
198 tracing::trace!(%connection_id, "outer loop iteration start");
199 let read_message = reader.read().fuse();
200 futures::pin_mut!(read_message);
201
202 loop {
203 tracing::trace!(%connection_id, "inner loop iteration start");
204 futures::select_biased! {
205 outgoing = outgoing_rx.next().fuse() => match outgoing {
206 Some(outgoing) => {
207 tracing::trace!(%connection_id, "outgoing rpc message: writing");
208 futures::select_biased! {
209 result = writer.write(outgoing).fuse() => {
210 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
211 result.context("failed to write RPC message")?;
212 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
213 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
214 }
215 _ = create_timer(WRITE_TIMEOUT).fuse() => {
216 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
217 Err(anyhow!("timed out writing message"))?;
218 }
219 }
220 }
221 None => {
222 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
223 return Ok(())
224 },
225 },
226 _ = keepalive_timer => {
227 tracing::trace!(%connection_id, "keepalive interval: pinging");
228 futures::select_biased! {
229 result = writer.write(proto::Message::Ping).fuse() => {
230 tracing::trace!(%connection_id, "keepalive interval: done pinging");
231 result.context("failed to send keepalive")?;
232 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
233 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
234 }
235 _ = create_timer(WRITE_TIMEOUT).fuse() => {
236 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
237 Err(anyhow!("timed out sending keepalive"))?;
238 }
239 }
240 }
241 incoming = read_message => {
242 let incoming = incoming.context("error reading rpc message from socket")?;
243 tracing::trace!(%connection_id, "incoming rpc message: received");
244 tracing::trace!(%connection_id, "receive timeout: resetting");
245 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
246 if let (proto::Message::Envelope(incoming), received_at) = incoming {
247 tracing::trace!(%connection_id, "incoming rpc message: processing");
248 futures::select_biased! {
249 result = incoming_tx.send((incoming, received_at)).fuse() => match result {
250 Ok(_) => {
251 tracing::trace!(%connection_id, "incoming rpc message: processed");
252 }
253 Err(_) => {
254 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
255 return Ok(())
256 }
257 },
258 _ = create_timer(WRITE_TIMEOUT).fuse() => {
259 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
260 Err(anyhow!("timed out processing incoming message"))?
261 }
262 }
263 }
264 break;
265 },
266 _ = receive_timeout => {
267 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
268 Err(anyhow!("delay between messages too long"))?
269 }
270 }
271 }
272 }
273 };
274
275 let response_channels = connection_state.response_channels.clone();
276 self.connections
277 .write()
278 .insert(connection_id, connection_state);
279
280 let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
281 let response_channels = response_channels.clone();
282 async move {
283 let message_id = incoming.id;
284 tracing::trace!(?incoming, "incoming message future: start");
285 let _end = util::defer(move || {
286 tracing::trace!(%connection_id, message_id, "incoming message future: end");
287 });
288
289 if let Some(responding_to) = incoming.responding_to {
290 tracing::trace!(
291 %connection_id,
292 message_id,
293 responding_to,
294 "incoming response: received"
295 );
296 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
297 if let Some(tx) = channel {
298 let requester_resumed = oneshot::channel();
299 if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
300 tracing::trace!(
301 %connection_id,
302 message_id,
303 responding_to = responding_to,
304 ?error,
305 "incoming response: request future dropped",
306 );
307 }
308
309 tracing::trace!(
310 %connection_id,
311 message_id,
312 responding_to,
313 "incoming response: waiting to resume requester"
314 );
315 let _ = requester_resumed.1.await;
316 tracing::trace!(
317 %connection_id,
318 message_id,
319 responding_to,
320 "incoming response: requester resumed"
321 );
322 } else {
323 let message_type =
324 proto::build_typed_envelope(connection_id, received_at, incoming)
325 .map(|p| p.payload_type_name());
326 tracing::warn!(
327 %connection_id,
328 message_id,
329 responding_to,
330 message_type,
331 "incoming response: unknown request"
332 );
333 }
334
335 None
336 } else {
337 tracing::trace!(%connection_id, message_id, "incoming message: received");
338 proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
339 || {
340 tracing::error!(
341 %connection_id,
342 message_id,
343 "unable to construct a typed envelope"
344 );
345 None
346 },
347 )
348 }
349 }
350 });
351 (connection_id, handle_io, incoming_rx.boxed())
352 }
353
354 #[cfg(any(test, feature = "test-support"))]
355 pub fn add_test_connection(
356 self: &Arc<Self>,
357 connection: Connection,
358 executor: gpui::BackgroundExecutor,
359 ) -> (
360 ConnectionId,
361 impl Future<Output = anyhow::Result<()>> + Send,
362 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
363 ) {
364 let executor = executor.clone();
365 self.add_connection(connection, move |duration| executor.timer(duration))
366 }
367
368 pub fn disconnect(&self, connection_id: ConnectionId) {
369 self.connections.write().remove(&connection_id);
370 }
371
372 #[cfg(any(test, feature = "test-support"))]
373 pub fn reset(&self, epoch: u32) {
374 self.next_connection_id.store(0, SeqCst);
375 self.epoch.store(epoch, SeqCst);
376 }
377
378 pub fn teardown(&self) {
379 self.connections.write().clear();
380 }
381
382 pub fn request<T: RequestMessage>(
383 &self,
384 receiver_id: ConnectionId,
385 request: T,
386 ) -> impl Future<Output = Result<T::Response>> {
387 self.request_internal(None, receiver_id, request)
388 .map_ok(|envelope| envelope.payload)
389 }
390
391 pub fn request_envelope<T: RequestMessage>(
392 &self,
393 receiver_id: ConnectionId,
394 request: T,
395 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
396 self.request_internal(None, receiver_id, request)
397 }
398
399 pub fn forward_request<T: RequestMessage>(
400 &self,
401 sender_id: ConnectionId,
402 receiver_id: ConnectionId,
403 request: T,
404 ) -> impl Future<Output = Result<T::Response>> {
405 self.request_internal(Some(sender_id), receiver_id, request)
406 .map_ok(|envelope| envelope.payload)
407 }
408
409 pub fn request_internal<T: RequestMessage>(
410 &self,
411 original_sender_id: Option<ConnectionId>,
412 receiver_id: ConnectionId,
413 request: T,
414 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
415 let (tx, rx) = oneshot::channel();
416 let send = self.connection_state(receiver_id).and_then(|connection| {
417 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
418 connection
419 .response_channels
420 .lock()
421 .as_mut()
422 .ok_or_else(|| anyhow!("connection was closed"))?
423 .insert(message_id, tx);
424 connection
425 .outgoing_tx
426 .unbounded_send(proto::Message::Envelope(request.into_envelope(
427 message_id,
428 None,
429 original_sender_id.map(Into::into),
430 )))
431 .map_err(|_| anyhow!("connection was closed"))?;
432 Ok(())
433 });
434 async move {
435 send?;
436 let (response, received_at, _barrier) =
437 rx.await.map_err(|_| anyhow!("connection was closed"))?;
438
439 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
440 Err(RpcError::from_proto(&error, T::NAME))
441 } else {
442 Ok(TypedEnvelope {
443 message_id: response.id,
444 sender_id: receiver_id,
445 original_sender_id: response.original_sender_id,
446 payload: T::Response::from_envelope(response)
447 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
448 received_at,
449 })
450 }
451 }
452 }
453
454 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
455 let connection = self.connection_state(receiver_id)?;
456 let message_id = connection
457 .next_message_id
458 .fetch_add(1, atomic::Ordering::SeqCst);
459 connection
460 .outgoing_tx
461 .unbounded_send(proto::Message::Envelope(
462 message.into_envelope(message_id, None, None),
463 ))?;
464 Ok(())
465 }
466
467 pub fn forward_send<T: EnvelopedMessage>(
468 &self,
469 sender_id: ConnectionId,
470 receiver_id: ConnectionId,
471 message: T,
472 ) -> Result<()> {
473 let connection = self.connection_state(receiver_id)?;
474 let message_id = connection
475 .next_message_id
476 .fetch_add(1, atomic::Ordering::SeqCst);
477 connection
478 .outgoing_tx
479 .unbounded_send(proto::Message::Envelope(message.into_envelope(
480 message_id,
481 None,
482 Some(sender_id.into()),
483 )))?;
484 Ok(())
485 }
486
487 pub fn respond<T: RequestMessage>(
488 &self,
489 receipt: Receipt<T>,
490 response: T::Response,
491 ) -> Result<()> {
492 let connection = self.connection_state(receipt.sender_id)?;
493 let message_id = connection
494 .next_message_id
495 .fetch_add(1, atomic::Ordering::SeqCst);
496 connection
497 .outgoing_tx
498 .unbounded_send(proto::Message::Envelope(response.into_envelope(
499 message_id,
500 Some(receipt.message_id),
501 None,
502 )))?;
503 Ok(())
504 }
505
506 pub fn respond_with_error<T: RequestMessage>(
507 &self,
508 receipt: Receipt<T>,
509 response: proto::Error,
510 ) -> Result<()> {
511 let connection = self.connection_state(receipt.sender_id)?;
512 let message_id = connection
513 .next_message_id
514 .fetch_add(1, atomic::Ordering::SeqCst);
515 connection
516 .outgoing_tx
517 .unbounded_send(proto::Message::Envelope(response.into_envelope(
518 message_id,
519 Some(receipt.message_id),
520 None,
521 )))?;
522 Ok(())
523 }
524
525 pub fn respond_with_unhandled_message(
526 &self,
527 envelope: Box<dyn AnyTypedEnvelope>,
528 ) -> Result<()> {
529 let connection = self.connection_state(envelope.sender_id())?;
530 let response = ErrorCode::Internal
531 .message(format!(
532 "message {} was not handled",
533 envelope.payload_type_name()
534 ))
535 .to_proto();
536 let message_id = connection
537 .next_message_id
538 .fetch_add(1, atomic::Ordering::SeqCst);
539 connection
540 .outgoing_tx
541 .unbounded_send(proto::Message::Envelope(response.into_envelope(
542 message_id,
543 Some(envelope.message_id()),
544 None,
545 )))?;
546 Ok(())
547 }
548
549 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
550 let connections = self.connections.read();
551 let connection = connections
552 .get(&connection_id)
553 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
554 Ok(connection.clone())
555 }
556}
557
558impl Serialize for Peer {
559 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
560 where
561 S: serde::Serializer,
562 {
563 let mut state = serializer.serialize_struct("Peer", 2)?;
564 state.serialize_field("connections", &*self.connections.read())?;
565 state.end()
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Installs `env_logger` when `RUST_LOG` is set.
    ///
    /// Uses `try_init` because `#[gpui::test(iterations = N)]` runs the test body many
    /// times in one process, and `env_logger::init` panics once a logger is installed.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            env_logger::try_init().ok();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}