1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
2
3use super::{
4 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
5 Connection,
6};
7use anyhow::{anyhow, Context, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::BoxStream,
12 FutureExt, SinkExt, StreamExt, TryFutureExt,
13};
14use parking_lot::{Mutex, RwLock};
15use serde::{ser::SerializeStruct, Serialize};
16use std::{fmt, sync::atomic::Ordering::SeqCst};
17use std::{
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24 time::Duration,
25};
26use tracing::instrument;
27
28#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
29pub struct ConnectionId {
30 pub owner_id: u32,
31 pub id: u32,
32}
33
34impl Into<PeerId> for ConnectionId {
35 fn into(self) -> PeerId {
36 PeerId {
37 owner_id: self.owner_id,
38 id: self.id,
39 }
40 }
41}
42
43impl From<PeerId> for ConnectionId {
44 fn from(peer_id: PeerId) -> Self {
45 Self {
46 owner_id: peer_id.owner_id,
47 id: peer_id.id,
48 }
49 }
50}
51
52impl fmt::Display for ConnectionId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}/{}", self.owner_id, self.id)
55 }
56}
57
58pub struct Receipt<T> {
59 pub sender_id: ConnectionId,
60 pub message_id: u32,
61 payload_type: PhantomData<T>,
62}
63
64impl<T> Clone for Receipt<T> {
65 fn clone(&self) -> Self {
66 Self {
67 sender_id: self.sender_id,
68 message_id: self.message_id,
69 payload_type: PhantomData,
70 }
71 }
72}
73
74impl<T> Copy for Receipt<T> {}
75
76#[derive(Clone, Debug)]
77pub struct TypedEnvelope<T> {
78 pub sender_id: ConnectionId,
79 pub original_sender_id: Option<PeerId>,
80 pub message_id: u32,
81 pub payload: T,
82}
83
84impl<T> TypedEnvelope<T> {
85 pub fn original_sender_id(&self) -> Result<PeerId> {
86 self.original_sender_id
87 .ok_or_else(|| anyhow!("missing original_sender_id"))
88 }
89}
90
91impl<T: RequestMessage> TypedEnvelope<T> {
92 pub fn receipt(&self) -> Receipt<T> {
93 Receipt {
94 sender_id: self.sender_id,
95 message_id: self.message_id,
96 payload_type: PhantomData,
97 }
98 }
99}
100
101pub struct Peer {
102 epoch: AtomicU32,
103 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
104 next_connection_id: AtomicU32,
105}
106
107#[derive(Clone, Serialize)]
108pub struct ConnectionState {
109 #[serde(skip)]
110 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
111 next_message_id: Arc<AtomicU32>,
112 #[allow(clippy::type_complexity)]
113 #[serde(skip)]
114 response_channels:
115 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
116}
117
/// If nothing has been written for this long, a `Ping` is sent.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Upper bound for a single socket write, and for handing one incoming
/// message to the incoming channel, before the connection is failed.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// The connection is failed if no message arrives for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
121
122impl Peer {
123 pub fn new(epoch: u32) -> Arc<Self> {
124 Arc::new(Self {
125 epoch: AtomicU32::new(epoch),
126 connections: Default::default(),
127 next_connection_id: Default::default(),
128 })
129 }
130
131 pub fn epoch(&self) -> u32 {
132 self.epoch.load(SeqCst)
133 }
134
135 #[instrument(skip_all)]
136 pub fn add_connection<F, Fut, Out>(
137 self: &Arc<Self>,
138 connection: Connection,
139 create_timer: F,
140 ) -> (
141 ConnectionId,
142 impl Future<Output = anyhow::Result<()>> + Send,
143 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
144 )
145 where
146 F: Send + Fn(Duration) -> Fut,
147 Fut: Send + Future<Output = Out>,
148 Out: Send,
149 {
150 // For outgoing messages, use an unbounded channel so that application code
151 // can always send messages without yielding. For incoming messages, use a
152 // bounded channel so that other peers will receive backpressure if they send
153 // messages faster than this peer can process them.
154 #[cfg(any(test, feature = "test-support"))]
155 const INCOMING_BUFFER_SIZE: usize = 1;
156 #[cfg(not(any(test, feature = "test-support")))]
157 const INCOMING_BUFFER_SIZE: usize = 64;
158 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
159 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
160
161 let connection_id = ConnectionId {
162 owner_id: self.epoch.load(SeqCst),
163 id: self.next_connection_id.fetch_add(1, SeqCst),
164 };
165 let connection_state = ConnectionState {
166 outgoing_tx,
167 next_message_id: Default::default(),
168 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
169 };
170 let mut writer = MessageStream::new(connection.tx);
171 let mut reader = MessageStream::new(connection.rx);
172
173 let this = self.clone();
174 let response_channels = connection_state.response_channels.clone();
175 let handle_io = async move {
176 tracing::trace!(%connection_id, "handle io future: start");
177
178 let _end_connection = util::defer(|| {
179 response_channels.lock().take();
180 this.connections.write().remove(&connection_id);
181 tracing::trace!(%connection_id, "handle io future: end");
182 });
183
184 // Send messages on this frequency so the connection isn't closed.
185 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
186 futures::pin_mut!(keepalive_timer);
187
188 // Disconnect if we don't receive messages at least this frequently.
189 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
190 futures::pin_mut!(receive_timeout);
191
192 loop {
193 tracing::trace!(%connection_id, "outer loop iteration start");
194 let read_message = reader.read().fuse();
195 futures::pin_mut!(read_message);
196
197 loop {
198 tracing::trace!(%connection_id, "inner loop iteration start");
199 futures::select_biased! {
200 outgoing = outgoing_rx.next().fuse() => match outgoing {
201 Some(outgoing) => {
202 tracing::trace!(%connection_id, "outgoing rpc message: writing");
203 futures::select_biased! {
204 result = writer.write(outgoing).fuse() => {
205 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
206 result.context("failed to write RPC message")?;
207 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
208 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
209 }
210 _ = create_timer(WRITE_TIMEOUT).fuse() => {
211 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
212 Err(anyhow!("timed out writing message"))?;
213 }
214 }
215 }
216 None => {
217 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
218 return Ok(())
219 },
220 },
221 _ = keepalive_timer => {
222 tracing::trace!(%connection_id, "keepalive interval: pinging");
223 futures::select_biased! {
224 result = writer.write(proto::Message::Ping).fuse() => {
225 tracing::trace!(%connection_id, "keepalive interval: done pinging");
226 result.context("failed to send keepalive")?;
227 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
228 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
229 }
230 _ = create_timer(WRITE_TIMEOUT).fuse() => {
231 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
232 Err(anyhow!("timed out sending keepalive"))?;
233 }
234 }
235 }
236 incoming = read_message => {
237 let incoming = incoming.context("error reading rpc message from socket")?;
238 tracing::trace!(%connection_id, "incoming rpc message: received");
239 tracing::trace!(%connection_id, "receive timeout: resetting");
240 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
241 if let proto::Message::Envelope(incoming) = incoming {
242 tracing::trace!(%connection_id, "incoming rpc message: processing");
243 futures::select_biased! {
244 result = incoming_tx.send(incoming).fuse() => match result {
245 Ok(_) => {
246 tracing::trace!(%connection_id, "incoming rpc message: processed");
247 }
248 Err(_) => {
249 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
250 return Ok(())
251 }
252 },
253 _ = create_timer(WRITE_TIMEOUT).fuse() => {
254 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
255 Err(anyhow!("timed out processing incoming message"))?
256 }
257 }
258 }
259 break;
260 },
261 _ = receive_timeout => {
262 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
263 Err(anyhow!("delay between messages too long"))?
264 }
265 }
266 }
267 }
268 };
269
270 let response_channels = connection_state.response_channels.clone();
271 self.connections
272 .write()
273 .insert(connection_id, connection_state);
274
275 let incoming_rx = incoming_rx.filter_map(move |incoming| {
276 let response_channels = response_channels.clone();
277 async move {
278 let message_id = incoming.id;
279 tracing::trace!(?incoming, "incoming message future: start");
280 let _end = util::defer(move || {
281 tracing::trace!(%connection_id, message_id, "incoming message future: end");
282 });
283
284 if let Some(responding_to) = incoming.responding_to {
285 tracing::trace!(
286 %connection_id,
287 message_id,
288 responding_to,
289 "incoming response: received"
290 );
291 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
292 if let Some(tx) = channel {
293 let requester_resumed = oneshot::channel();
294 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
295 tracing::trace!(
296 %connection_id,
297 message_id,
298 responding_to = responding_to,
299 ?error,
300 "incoming response: request future dropped",
301 );
302 }
303
304 tracing::trace!(
305 %connection_id,
306 message_id,
307 responding_to,
308 "incoming response: waiting to resume requester"
309 );
310 let _ = requester_resumed.1.await;
311 tracing::trace!(
312 %connection_id,
313 message_id,
314 responding_to,
315 "incoming response: requester resumed"
316 );
317 } else {
318 let message_type = proto::build_typed_envelope(connection_id, incoming)
319 .map(|p| p.payload_type_name());
320 tracing::warn!(
321 %connection_id,
322 message_id,
323 responding_to,
324 message_type,
325 "incoming response: unknown request"
326 );
327 }
328
329 None
330 } else {
331 tracing::trace!(%connection_id, message_id, "incoming message: received");
332 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
333 tracing::error!(
334 %connection_id,
335 message_id,
336 "unable to construct a typed envelope"
337 );
338 None
339 })
340 }
341 }
342 });
343 (connection_id, handle_io, incoming_rx.boxed())
344 }
345
346 #[cfg(any(test, feature = "test-support"))]
347 pub fn add_test_connection(
348 self: &Arc<Self>,
349 connection: Connection,
350 executor: gpui::BackgroundExecutor,
351 ) -> (
352 ConnectionId,
353 impl Future<Output = anyhow::Result<()>> + Send,
354 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
355 ) {
356 let executor = executor.clone();
357 self.add_connection(connection, move |duration| executor.timer(duration))
358 }
359
360 pub fn disconnect(&self, connection_id: ConnectionId) {
361 self.connections.write().remove(&connection_id);
362 }
363
364 #[cfg(any(test, feature = "test-support"))]
365 pub fn reset(&self, epoch: u32) {
366 self.next_connection_id.store(0, SeqCst);
367 self.epoch.store(epoch, SeqCst);
368 }
369
370 pub fn teardown(&self) {
371 self.connections.write().clear();
372 }
373
374 pub fn request<T: RequestMessage>(
375 &self,
376 receiver_id: ConnectionId,
377 request: T,
378 ) -> impl Future<Output = Result<T::Response>> {
379 self.request_internal(None, receiver_id, request)
380 .map_ok(|envelope| envelope.payload)
381 }
382
383 pub fn request_envelope<T: RequestMessage>(
384 &self,
385 receiver_id: ConnectionId,
386 request: T,
387 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
388 self.request_internal(None, receiver_id, request)
389 }
390
391 pub fn forward_request<T: RequestMessage>(
392 &self,
393 sender_id: ConnectionId,
394 receiver_id: ConnectionId,
395 request: T,
396 ) -> impl Future<Output = Result<T::Response>> {
397 self.request_internal(Some(sender_id), receiver_id, request)
398 .map_ok(|envelope| envelope.payload)
399 }
400
401 pub fn request_internal<T: RequestMessage>(
402 &self,
403 original_sender_id: Option<ConnectionId>,
404 receiver_id: ConnectionId,
405 request: T,
406 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
407 let (tx, rx) = oneshot::channel();
408 let send = self.connection_state(receiver_id).and_then(|connection| {
409 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
410 connection
411 .response_channels
412 .lock()
413 .as_mut()
414 .ok_or_else(|| anyhow!("connection was closed"))?
415 .insert(message_id, tx);
416 connection
417 .outgoing_tx
418 .unbounded_send(proto::Message::Envelope(request.into_envelope(
419 message_id,
420 None,
421 original_sender_id.map(Into::into),
422 )))
423 .map_err(|_| anyhow!("connection was closed"))?;
424 Ok(())
425 });
426 async move {
427 send?;
428 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
429
430 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
431 Err(RpcError::from_proto(&error, T::NAME))
432 } else {
433 Ok(TypedEnvelope {
434 message_id: response.id,
435 sender_id: receiver_id,
436 original_sender_id: response.original_sender_id,
437 payload: T::Response::from_envelope(response)
438 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
439 })
440 }
441 }
442 }
443
444 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
445 let connection = self.connection_state(receiver_id)?;
446 let message_id = connection
447 .next_message_id
448 .fetch_add(1, atomic::Ordering::SeqCst);
449 connection
450 .outgoing_tx
451 .unbounded_send(proto::Message::Envelope(
452 message.into_envelope(message_id, None, None),
453 ))?;
454 Ok(())
455 }
456
457 pub fn forward_send<T: EnvelopedMessage>(
458 &self,
459 sender_id: ConnectionId,
460 receiver_id: ConnectionId,
461 message: T,
462 ) -> Result<()> {
463 let connection = self.connection_state(receiver_id)?;
464 let message_id = connection
465 .next_message_id
466 .fetch_add(1, atomic::Ordering::SeqCst);
467 connection
468 .outgoing_tx
469 .unbounded_send(proto::Message::Envelope(message.into_envelope(
470 message_id,
471 None,
472 Some(sender_id.into()),
473 )))?;
474 Ok(())
475 }
476
477 pub fn respond<T: RequestMessage>(
478 &self,
479 receipt: Receipt<T>,
480 response: T::Response,
481 ) -> Result<()> {
482 let connection = self.connection_state(receipt.sender_id)?;
483 let message_id = connection
484 .next_message_id
485 .fetch_add(1, atomic::Ordering::SeqCst);
486 connection
487 .outgoing_tx
488 .unbounded_send(proto::Message::Envelope(response.into_envelope(
489 message_id,
490 Some(receipt.message_id),
491 None,
492 )))?;
493 Ok(())
494 }
495
496 pub fn respond_with_error<T: RequestMessage>(
497 &self,
498 receipt: Receipt<T>,
499 response: proto::Error,
500 ) -> Result<()> {
501 let connection = self.connection_state(receipt.sender_id)?;
502 let message_id = connection
503 .next_message_id
504 .fetch_add(1, atomic::Ordering::SeqCst);
505 connection
506 .outgoing_tx
507 .unbounded_send(proto::Message::Envelope(response.into_envelope(
508 message_id,
509 Some(receipt.message_id),
510 None,
511 )))?;
512 Ok(())
513 }
514
515 pub fn respond_with_unhandled_message(
516 &self,
517 envelope: Box<dyn AnyTypedEnvelope>,
518 ) -> Result<()> {
519 let connection = self.connection_state(envelope.sender_id())?;
520 let response = ErrorCode::Internal
521 .message(format!(
522 "message {} was not handled",
523 envelope.payload_type_name()
524 ))
525 .to_proto();
526 let message_id = connection
527 .next_message_id
528 .fetch_add(1, atomic::Ordering::SeqCst);
529 connection
530 .outgoing_tx
531 .unbounded_send(proto::Message::Envelope(response.into_envelope(
532 message_id,
533 Some(envelope.message_id()),
534 None,
535 )))?;
536 Ok(())
537 }
538
539 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
540 let connections = self.connections.read();
541 let connection = connections
542 .get(&connection_id)
543 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
544 Ok(connection.clone())
545 }
546}
547
548impl Serialize for Peer {
549 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
550 where
551 S: serde::Serializer,
552 {
553 let mut state = serializer.serialize_struct("Peer", 2)?;
554 state.serialize_field("connections", &*self.connections.read())?;
555 state.end()
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Installs `env_logger` when `RUST_LOG` is set.
    ///
    /// This may run many times in one process: there are several tests, and
    /// each `#[gpui::test(iterations = ..)]` repeats its body. `env_logger::init`
    /// panics once a logger is installed, so the fallible `try_init` is used
    /// and a repeat attempt is ignored.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            let _ = env_logger::try_init();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}