1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
2
3use super::{
4 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
5 Connection,
6};
7use anyhow::{anyhow, Context, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::BoxStream,
12 FutureExt, SinkExt, StreamExt, TryFutureExt,
13};
14use parking_lot::{Mutex, RwLock};
15use serde::{ser::SerializeStruct, Serialize};
16use std::{fmt, sync::atomic::Ordering::SeqCst};
17use std::{
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24 time::Duration,
25};
26use tracing::instrument;
27
/// Identifies a single connection held by a [`Peer`].
///
/// Mirrors the wire-format `PeerId`: conversions in both directions exist
/// below.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// The peer's epoch at the time the connection was added; bumped by
    /// `Peer::reset` so ids from before a restart don't collide.
    pub owner_id: u32,
    /// Monotonically increasing counter assigned per connection.
    pub id: u32,
}
33
34impl Into<PeerId> for ConnectionId {
35 fn into(self) -> PeerId {
36 PeerId {
37 owner_id: self.owner_id,
38 id: self.id,
39 }
40 }
41}
42
43impl From<PeerId> for ConnectionId {
44 fn from(peer_id: PeerId) -> Self {
45 Self {
46 owner_id: peer_id.owner_id,
47 id: peer_id.id,
48 }
49 }
50}
51
52impl fmt::Display for ConnectionId {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 write!(f, "{}/{}", self.owner_id, self.id)
55 }
56}
57
/// A handle for a received request of type `T`, which can later be passed to
/// [`Peer::respond`] / [`Peer::respond_with_error`] to send the reply.
pub struct Receipt<T> {
    /// The connection the request arrived on (and where the reply goes).
    pub sender_id: ConnectionId,
    /// The sender-assigned id of the request; replies reference it.
    pub message_id: u32,
    // Ties this receipt to the request's message type without storing it.
    payload_type: PhantomData<T>,
}
63
64impl<T> Clone for Receipt<T> {
65 fn clone(&self) -> Self {
66 Self {
67 sender_id: self.sender_id,
68 message_id: self.message_id,
69 payload_type: PhantomData,
70 }
71 }
72}
73
74impl<T> Copy for Receipt<T> {}
75
/// A decoded incoming message of type `T`, together with routing metadata.
#[derive(Clone, Debug)]
pub struct TypedEnvelope<T> {
    /// The connection this message arrived on.
    pub sender_id: ConnectionId,
    /// For forwarded messages, the peer that originally sent the message.
    pub original_sender_id: Option<PeerId>,
    /// The sender-assigned id of this message.
    pub message_id: u32,
    /// The decoded message payload.
    pub payload: T,
}
83
84impl<T> TypedEnvelope<T> {
85 pub fn original_sender_id(&self) -> Result<PeerId> {
86 self.original_sender_id
87 .ok_or_else(|| anyhow!("missing original_sender_id"))
88 }
89}
90
91impl<T: RequestMessage> TypedEnvelope<T> {
92 pub fn receipt(&self) -> Receipt<T> {
93 Receipt {
94 sender_id: self.sender_id,
95 message_id: self.message_id,
96 payload_type: PhantomData,
97 }
98 }
99}
100
/// A node that can exchange RPC messages over any number of [`Connection`]s,
/// multiplexing requests, responses, and fire-and-forget messages on each.
pub struct Peer {
    // Used as the `owner_id` half of newly assigned `ConnectionId`s; updated
    // by `reset`.
    epoch: AtomicU32,
    /// Live connections, keyed by id. Entries are removed on `disconnect`,
    /// `teardown`, or when a connection's I/O future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of the `id` half of newly assigned `ConnectionId`s.
    next_connection_id: AtomicU32,
}
106
/// Per-connection bookkeeping shared between [`Peer`] and the connection's
/// I/O future. Cloning is cheap: the fields are a channel sender and `Arc`s.
#[derive(Clone, Serialize)]
pub struct ConnectionState {
    // Channel feeding the connection's writer loop in `add_connection`.
    #[serde(skip)]
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    /// Source of ids for messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending requests awaiting responses, keyed by the outgoing message id.
    // Set to `None` when the connection closes so later requests fail fast.
    // Each entry's sender carries the response envelope plus a "barrier"
    // sender the requester drops to resume the connection's message loop.
    #[allow(clippy::type_complexity)]
    #[serde(skip)]
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
117
/// How often to send a ping when no other message has been written.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed for a single write, or for handing an incoming
/// message to the application, before the connection is deemed dead.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Disconnect if no message at all is received within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
121
122impl Peer {
123 pub fn new(epoch: u32) -> Arc<Self> {
124 Arc::new(Self {
125 epoch: AtomicU32::new(epoch),
126 connections: Default::default(),
127 next_connection_id: Default::default(),
128 })
129 }
130
    /// Returns the current epoch, i.e. the `owner_id` that will be assigned
    /// to newly added connections.
    pub fn epoch(&self) -> u32 {
        self.epoch.load(SeqCst)
    }
134
    /// Registers a new connection with this peer.
    ///
    /// Returns the id assigned to the connection, a future that drives all
    /// I/O for the connection (the caller must spawn it), and a stream of
    /// decoded incoming messages. `create_timer` abstracts over timers so
    /// tests can use a deterministic executor.
    ///
    /// When the I/O future completes (error, timeout, or channel closure),
    /// the connection's state is removed from `self.connections` and all
    /// pending response channels are dropped.
    #[instrument(skip_all)]
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Clean up on any exit path: drop pending response channels (so
            // in-flight requests fail) and unregister the connection.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // The outer loop creates a fresh read future; the inner loop
            // services writes/keepalives while that single read is pending,
            // and `break`s back out once a message has been read.
            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        Err(anyhow!("timed out writing message"))?;
                                    }
                                }
                            }
                            None => {
                                // All senders dropped; the connection was
                                // removed from the peer, so shut down cleanly.
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(proto::Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    Err(anyhow!("timed out sending keepalive"))?;
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                futures::select_biased! {
                                    result = incoming_tx.send(incoming).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        Err(anyhow!("timed out processing incoming message"))?
                                    }
                                }
                            }
                            // A message was consumed; restart the outer loop
                            // to begin a fresh read.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Demultiplex incoming envelopes: responses are routed to their
        // waiting request future (and filtered out of the stream); all other
        // messages are decoded into typed envelopes for the application.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        // Wait until the requester drops its barrier sender
                        // before processing further messages, so responses
                        // are observed before later incoming messages.
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else {
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: unknown request"
                        );
                    }

                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        tracing::error!(
                            %connection_id,
                            message_id,
                            "unable to construct a typed envelope"
                        );
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
342
343 #[cfg(any(test, feature = "test-support"))]
344 pub fn add_test_connection(
345 self: &Arc<Self>,
346 connection: Connection,
347 executor: gpui::BackgroundExecutor,
348 ) -> (
349 ConnectionId,
350 impl Future<Output = anyhow::Result<()>> + Send,
351 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
352 ) {
353 let executor = executor.clone();
354 self.add_connection(connection, move |duration| executor.timer(duration))
355 }
356
357 pub fn disconnect(&self, connection_id: ConnectionId) {
358 self.connections.write().remove(&connection_id);
359 }
360
    /// Tears down all connections and restarts id assignment under a new
    /// epoch, so ids handed out after the reset don't collide with old ones.
    pub fn reset(&self, epoch: u32) {
        self.teardown();
        self.next_connection_id.store(0, SeqCst);
        self.epoch.store(epoch, SeqCst);
    }
366
367 pub fn teardown(&self) {
368 self.connections.write().clear();
369 }
370
371 pub fn request<T: RequestMessage>(
372 &self,
373 receiver_id: ConnectionId,
374 request: T,
375 ) -> impl Future<Output = Result<T::Response>> {
376 self.request_internal(None, receiver_id, request)
377 .map_ok(|envelope| envelope.payload)
378 }
379
    /// Sends a request and resolves to the full response envelope, which
    /// carries the message id and sender metadata alongside the payload.
    pub fn request_envelope<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
        self.request_internal(None, receiver_id, request)
    }
387
388 pub fn forward_request<T: RequestMessage>(
389 &self,
390 sender_id: ConnectionId,
391 receiver_id: ConnectionId,
392 request: T,
393 ) -> impl Future<Output = Result<T::Response>> {
394 self.request_internal(Some(sender_id), receiver_id, request)
395 .map_ok(|envelope| envelope.payload)
396 }
397
398 pub fn request_internal<T: RequestMessage>(
399 &self,
400 original_sender_id: Option<ConnectionId>,
401 receiver_id: ConnectionId,
402 request: T,
403 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
404 let (tx, rx) = oneshot::channel();
405 let send = self.connection_state(receiver_id).and_then(|connection| {
406 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
407 connection
408 .response_channels
409 .lock()
410 .as_mut()
411 .ok_or_else(|| anyhow!("connection was closed"))?
412 .insert(message_id, tx);
413 connection
414 .outgoing_tx
415 .unbounded_send(proto::Message::Envelope(request.into_envelope(
416 message_id,
417 None,
418 original_sender_id.map(Into::into),
419 )))
420 .map_err(|_| anyhow!("connection was closed"))?;
421 Ok(())
422 });
423 async move {
424 send?;
425 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
426
427 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
428 Err(RpcError::from_proto(&error, T::NAME))
429 } else {
430 Ok(TypedEnvelope {
431 message_id: response.id,
432 sender_id: receiver_id,
433 original_sender_id: response.original_sender_id,
434 payload: T::Response::from_envelope(response)
435 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
436 })
437 }
438 }
439 }
440
441 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
442 let connection = self.connection_state(receiver_id)?;
443 let message_id = connection
444 .next_message_id
445 .fetch_add(1, atomic::Ordering::SeqCst);
446 connection
447 .outgoing_tx
448 .unbounded_send(proto::Message::Envelope(
449 message.into_envelope(message_id, None, None),
450 ))?;
451 Ok(())
452 }
453
454 pub fn forward_send<T: EnvelopedMessage>(
455 &self,
456 sender_id: ConnectionId,
457 receiver_id: ConnectionId,
458 message: T,
459 ) -> Result<()> {
460 let connection = self.connection_state(receiver_id)?;
461 let message_id = connection
462 .next_message_id
463 .fetch_add(1, atomic::Ordering::SeqCst);
464 connection
465 .outgoing_tx
466 .unbounded_send(proto::Message::Envelope(message.into_envelope(
467 message_id,
468 None,
469 Some(sender_id.into()),
470 )))?;
471 Ok(())
472 }
473
474 pub fn respond<T: RequestMessage>(
475 &self,
476 receipt: Receipt<T>,
477 response: T::Response,
478 ) -> Result<()> {
479 let connection = self.connection_state(receipt.sender_id)?;
480 let message_id = connection
481 .next_message_id
482 .fetch_add(1, atomic::Ordering::SeqCst);
483 connection
484 .outgoing_tx
485 .unbounded_send(proto::Message::Envelope(response.into_envelope(
486 message_id,
487 Some(receipt.message_id),
488 None,
489 )))?;
490 Ok(())
491 }
492
493 pub fn respond_with_error<T: RequestMessage>(
494 &self,
495 receipt: Receipt<T>,
496 response: proto::Error,
497 ) -> Result<()> {
498 let connection = self.connection_state(receipt.sender_id)?;
499 let message_id = connection
500 .next_message_id
501 .fetch_add(1, atomic::Ordering::SeqCst);
502 connection
503 .outgoing_tx
504 .unbounded_send(proto::Message::Envelope(response.into_envelope(
505 message_id,
506 Some(receipt.message_id),
507 None,
508 )))?;
509 Ok(())
510 }
511
512 pub fn respond_with_unhandled_message(
513 &self,
514 envelope: Box<dyn AnyTypedEnvelope>,
515 ) -> Result<()> {
516 let connection = self.connection_state(envelope.sender_id())?;
517 let response = ErrorCode::Internal
518 .message(format!(
519 "message {} was not handled",
520 envelope.payload_type_name()
521 ))
522 .to_proto();
523 let message_id = connection
524 .next_message_id
525 .fetch_add(1, atomic::Ordering::SeqCst);
526 connection
527 .outgoing_tx
528 .unbounded_send(proto::Message::Envelope(response.into_envelope(
529 message_id,
530 Some(envelope.message_id()),
531 None,
532 )))?;
533 Ok(())
534 }
535
536 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
537 let connections = self.connections.read();
538 let connection = connections
539 .get(&connection_id)
540 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
541 Ok(connection.clone())
542 }
543}
544
545impl Serialize for Peer {
546 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
547 where
548 S: serde::Serializer,
549 {
550 let mut state = serializer.serialize_struct("Peer", 2)?;
551 state.serialize_field("connections", &*self.connections.read())?;
552 state.end()
553 }
554}
555
556#[cfg(test)]
557mod tests {
558 use super::*;
559 use crate::TypedEnvelope;
560 use async_tungstenite::tungstenite::Message as WebSocketMessage;
561 use gpui::TestAppContext;
562
563 fn init_logger() {
564 if std::env::var("RUST_LOG").is_ok() {
565 env_logger::init();
566 }
567 }
568
569 #[gpui::test(iterations = 50)]
570 async fn test_request_response(cx: &mut TestAppContext) {
571 init_logger();
572
573 let executor = cx.executor();
574
575 // create 2 clients connected to 1 server
576 let server = Peer::new(0);
577 let client1 = Peer::new(0);
578 let client2 = Peer::new(0);
579
580 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
581 Connection::in_memory(cx.executor());
582 let (client1_conn_id, io_task1, client1_incoming) =
583 client1.add_test_connection(client1_to_server_conn, cx.executor());
584 let (_, io_task2, server_incoming1) =
585 server.add_test_connection(server_to_client_1_conn, cx.executor());
586
587 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
588 Connection::in_memory(cx.executor());
589 let (client2_conn_id, io_task3, client2_incoming) =
590 client2.add_test_connection(client2_to_server_conn, cx.executor());
591 let (_, io_task4, server_incoming2) =
592 server.add_test_connection(server_to_client_2_conn, cx.executor());
593
594 executor.spawn(io_task1).detach();
595 executor.spawn(io_task2).detach();
596 executor.spawn(io_task3).detach();
597 executor.spawn(io_task4).detach();
598 executor
599 .spawn(handle_messages(server_incoming1, server.clone()))
600 .detach();
601 executor
602 .spawn(handle_messages(client1_incoming, client1.clone()))
603 .detach();
604 executor
605 .spawn(handle_messages(server_incoming2, server.clone()))
606 .detach();
607 executor
608 .spawn(handle_messages(client2_incoming, client2.clone()))
609 .detach();
610
611 assert_eq!(
612 client1
613 .request(client1_conn_id, proto::Ping {},)
614 .await
615 .unwrap(),
616 proto::Ack {}
617 );
618
619 assert_eq!(
620 client2
621 .request(client2_conn_id, proto::Ping {},)
622 .await
623 .unwrap(),
624 proto::Ack {}
625 );
626
627 assert_eq!(
628 client1
629 .request(client1_conn_id, proto::Test { id: 1 },)
630 .await
631 .unwrap(),
632 proto::Test { id: 1 }
633 );
634
635 assert_eq!(
636 client2
637 .request(client2_conn_id, proto::Test { id: 2 })
638 .await
639 .unwrap(),
640 proto::Test { id: 2 }
641 );
642
643 client1.disconnect(client1_conn_id);
644 client2.disconnect(client1_conn_id);
645
646 async fn handle_messages(
647 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
648 peer: Arc<Peer>,
649 ) -> Result<()> {
650 while let Some(envelope) = messages.next().await {
651 let envelope = envelope.into_any();
652 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
653 let receipt = envelope.receipt();
654 peer.respond(receipt, proto::Ack {})?
655 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
656 {
657 peer.respond(envelope.receipt(), envelope.payload.clone())?
658 } else {
659 panic!("unknown message type");
660 }
661 }
662
663 Ok(())
664 }
665 }
666
    /// Verifies ordering: messages the server sends *before* its response
    /// must be observed by the client before the response future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Send two plain messages first, then respond to the request.
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both plain messages must be recorded before the response.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
764
    /// Verifies that dropping an in-flight request future does not disturb
    /// other traffic: a second concurrent request still completes, and
    /// message ordering is preserved.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Send two plain messages, then respond to both requests.
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
876
    /// Verifies that `disconnect` terminates the I/O future, ends the
    /// incoming-message stream, and closes the underlying connection (as
    /// observed by the remote side's send failing).
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        // The remote end should now observe the connection as closed.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
912
    /// Verifies that dropping the remote end mid-request fails the pending
    /// request with a "connection was closed" error.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request has actually reached the server side.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
935}