1use super::{
2 Connection,
3 message_stream::{Message, MessageStream},
4 proto::{
5 self, AnyTypedEnvelope, EnvelopedMessage, PeerId, Receipt, RequestMessage, TypedEnvelope,
6 },
7};
8use anyhow::{Context as _, Result, anyhow};
9use collections::HashMap;
10use futures::{
11 FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
12 channel::{mpsc, oneshot},
13 stream::BoxStream,
14};
15use parking_lot::{Mutex, RwLock};
16use proto::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
17use serde::{Serialize, ser::SerializeStruct};
18use std::{
19 fmt, future,
20 future::Future,
21 sync::atomic::Ordering::SeqCst,
22 sync::{
23 Arc,
24 atomic::{self, AtomicU32},
25 },
26 time::Duration,
27 time::Instant,
28};
29use tracing::instrument;
30
31#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
32pub struct ConnectionId {
33 pub owner_id: u32,
34 pub id: u32,
35}
36
37impl From<ConnectionId> for PeerId {
38 fn from(id: ConnectionId) -> Self {
39 PeerId {
40 owner_id: id.owner_id,
41 id: id.id,
42 }
43 }
44}
45
46impl From<PeerId> for ConnectionId {
47 fn from(peer_id: PeerId) -> Self {
48 Self {
49 owner_id: peer_id.owner_id,
50 id: peer_id.id,
51 }
52 }
53}
54
55impl fmt::Display for ConnectionId {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "{}/{}", self.owner_id, self.id)
58 }
59}
60
61pub struct Peer {
62 epoch: AtomicU32,
63 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
64 next_connection_id: AtomicU32,
65}
66
67#[derive(Clone, Serialize)]
68pub struct ConnectionState {
69 #[serde(skip)]
70 outgoing_tx: mpsc::UnboundedSender<Message>,
71 next_message_id: Arc<AtomicU32>,
72 #[allow(clippy::type_complexity)]
73 #[serde(skip)]
74 response_channels: Arc<
75 Mutex<
76 Option<
77 HashMap<
78 u32,
79 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
80 >,
81 >,
82 >,
83 >,
84 #[allow(clippy::type_complexity)]
85 #[serde(skip)]
86 stream_response_channels: Arc<
87 Mutex<
88 Option<
89 HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
90 >,
91 >,
92 >,
93}
94
/// How long the I/O loop may go without writing before it sends a ping, so the remote
/// side's receive timeout never fires on an idle but healthy connection.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time a single write (or hand-off of an incoming message) may take.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// The connection is dropped if nothing at all is received for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(10);
98
99impl Peer {
100 pub fn new(epoch: u32) -> Arc<Self> {
101 Arc::new(Self {
102 epoch: AtomicU32::new(epoch),
103 connections: Default::default(),
104 next_connection_id: Default::default(),
105 })
106 }
107
108 pub fn epoch(&self) -> u32 {
109 self.epoch.load(SeqCst)
110 }
111
112 #[instrument(skip_all)]
113 pub fn add_connection<F, Fut, Out>(
114 self: &Arc<Self>,
115 connection: Connection,
116 create_timer: F,
117 ) -> (
118 ConnectionId,
119 impl Future<Output = anyhow::Result<()>> + Send + use<F, Fut, Out>,
120 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
121 )
122 where
123 F: Send + Fn(Duration) -> Fut,
124 Fut: Send + Future<Output = Out>,
125 Out: Send,
126 {
127 // For outgoing messages, use an unbounded channel so that application code
128 // can always send messages without yielding. For incoming messages, use a
129 // bounded channel so that other peers will receive backpressure if they send
130 // messages faster than this peer can process them.
131 #[cfg(any(test, feature = "test-support"))]
132 const INCOMING_BUFFER_SIZE: usize = 1;
133 #[cfg(not(any(test, feature = "test-support")))]
134 const INCOMING_BUFFER_SIZE: usize = 256;
135 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
136 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
137
138 let connection_id = ConnectionId {
139 owner_id: self.epoch.load(SeqCst),
140 id: self.next_connection_id.fetch_add(1, SeqCst),
141 };
142 let connection_state = ConnectionState {
143 outgoing_tx,
144 next_message_id: Default::default(),
145 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
146 stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
147 };
148 let mut writer = MessageStream::new(connection.tx);
149 let mut reader = MessageStream::new(connection.rx);
150
151 let this = self.clone();
152 let response_channels = connection_state.response_channels.clone();
153 let stream_response_channels = connection_state.stream_response_channels.clone();
154
155 let handle_io = async move {
156 tracing::trace!(%connection_id, "handle io future: start");
157
158 let _end_connection = util::defer(|| {
159 response_channels.lock().take();
160 if let Some(channels) = stream_response_channels.lock().take() {
161 for channel in channels.values() {
162 let _ = channel.unbounded_send((
163 Err(anyhow!("connection closed")),
164 oneshot::channel().0,
165 ));
166 }
167 }
168 this.connections.write().remove(&connection_id);
169 tracing::trace!(%connection_id, "handle io future: end");
170 });
171
172 // Send messages on this frequency so the connection isn't closed.
173 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
174 futures::pin_mut!(keepalive_timer);
175
176 // Disconnect if we don't receive messages at least this frequently.
177 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
178 futures::pin_mut!(receive_timeout);
179
180 loop {
181 tracing::trace!(%connection_id, "outer loop iteration start");
182 let read_message = reader.read().fuse();
183 futures::pin_mut!(read_message);
184
185 loop {
186 tracing::trace!(%connection_id, "inner loop iteration start");
187 futures::select_biased! {
188 outgoing = outgoing_rx.next().fuse() => match outgoing {
189 Some(outgoing) => {
190 tracing::trace!(%connection_id, "outgoing rpc message: writing");
191 futures::select_biased! {
192 result = writer.write(outgoing).fuse() => {
193 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
194 result.context("failed to write RPC message")?;
195 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
196 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
197 }
198 _ = create_timer(WRITE_TIMEOUT).fuse() => {
199 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
200 anyhow::bail!("timed out writing message");
201 }
202 }
203 }
204 None => {
205 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
206 return Ok(())
207 },
208 },
209 _ = keepalive_timer => {
210 tracing::trace!(%connection_id, "keepalive interval: pinging");
211 futures::select_biased! {
212 result = writer.write(Message::Ping).fuse() => {
213 tracing::trace!(%connection_id, "keepalive interval: done pinging");
214 result.context("failed to send keepalive")?;
215 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
216 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
217 }
218 _ = create_timer(WRITE_TIMEOUT).fuse() => {
219 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
220 anyhow::bail!("timed out sending keepalive");
221 }
222 }
223 }
224 incoming = read_message => {
225 let incoming = incoming.context("error reading rpc message from socket")?;
226 tracing::trace!(%connection_id, "incoming rpc message: received");
227 tracing::trace!(%connection_id, "receive timeout: resetting");
228 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
229 if let (Message::Envelope(incoming), received_at) = incoming {
230 tracing::trace!(%connection_id, "incoming rpc message: processing");
231 futures::select_biased! {
232 result = incoming_tx.send((incoming, received_at)).fuse() => match result {
233 Ok(_) => {
234 tracing::trace!(%connection_id, "incoming rpc message: processed");
235 }
236 Err(_) => {
237 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
238 return Ok(())
239 }
240 },
241 _ = create_timer(WRITE_TIMEOUT).fuse() => {
242 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
243 anyhow::bail!("timed out processing incoming message");
244 }
245 }
246 }
247 break;
248 },
249 _ = receive_timeout => {
250 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
251 anyhow::bail!("delay between messages too long");
252 }
253 }
254 }
255 }
256 };
257
258 let response_channels = connection_state.response_channels.clone();
259 let stream_response_channels = connection_state.stream_response_channels.clone();
260 self.connections
261 .write()
262 .insert(connection_id, connection_state);
263
264 let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
265 let response_channels = response_channels.clone();
266 let stream_response_channels = stream_response_channels.clone();
267 async move {
268 let message_id = incoming.id;
269 tracing::trace!(?incoming, "incoming message future: start");
270 let _end = util::defer(move || {
271 tracing::trace!(%connection_id, message_id, "incoming message future: end");
272 });
273
274 if let Some(responding_to) = incoming.responding_to {
275 tracing::trace!(
276 %connection_id,
277 message_id,
278 responding_to,
279 "incoming response: received"
280 );
281 let response_channel =
282 response_channels.lock().as_mut()?.remove(&responding_to);
283 let stream_response_channel = stream_response_channels
284 .lock()
285 .as_ref()?
286 .get(&responding_to)
287 .cloned();
288
289 if let Some(tx) = response_channel {
290 let requester_resumed = oneshot::channel();
291 if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
292 tracing::trace!(
293 %connection_id,
294 message_id,
295 responding_to = responding_to,
296 ?error,
297 "incoming response: request future dropped",
298 );
299 }
300
301 tracing::trace!(
302 %connection_id,
303 message_id,
304 responding_to,
305 "incoming response: waiting to resume requester"
306 );
307 let _ = requester_resumed.1.await;
308 tracing::trace!(
309 %connection_id,
310 message_id,
311 responding_to,
312 "incoming response: requester resumed"
313 );
314 } else if let Some(tx) = stream_response_channel {
315 let requester_resumed = oneshot::channel();
316 if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
317 tracing::debug!(
318 %connection_id,
319 message_id,
320 responding_to = responding_to,
321 ?error,
322 "incoming stream response: request future dropped",
323 );
324 }
325
326 tracing::debug!(
327 %connection_id,
328 message_id,
329 responding_to,
330 "incoming stream response: waiting to resume requester"
331 );
332 let _ = requester_resumed.1.await;
333 tracing::debug!(
334 %connection_id,
335 message_id,
336 responding_to,
337 "incoming stream response: requester resumed"
338 );
339 } else {
340 let message_type = proto::build_typed_envelope(
341 connection_id.into(),
342 received_at,
343 incoming,
344 )
345 .map(|p| p.payload_type_name());
346 tracing::warn!(
347 %connection_id,
348 message_id,
349 responding_to,
350 message_type,
351 "incoming response: unknown request"
352 );
353 }
354
355 None
356 } else {
357 tracing::trace!(%connection_id, message_id, "incoming message: received");
358 proto::build_typed_envelope(connection_id.into(), received_at, incoming)
359 .or_else(|| {
360 tracing::error!(
361 %connection_id,
362 message_id,
363 "unable to construct a typed envelope"
364 );
365 None
366 })
367 }
368 }
369 });
370 (connection_id, handle_io, incoming_rx.boxed())
371 }
372
373 #[cfg(any(test, feature = "test-support"))]
374 pub fn add_test_connection(
375 self: &Arc<Self>,
376 connection: Connection,
377 executor: gpui::BackgroundExecutor,
378 ) -> (
379 ConnectionId,
380 impl Future<Output = anyhow::Result<()>> + Send + use<>,
381 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
382 ) {
383 let executor = executor.clone();
384 self.add_connection(connection, move |duration| executor.timer(duration))
385 }
386
387 pub fn disconnect(&self, connection_id: ConnectionId) {
388 self.connections.write().remove(&connection_id);
389 }
390
391 #[cfg(any(test, feature = "test-support"))]
392 pub fn reset(&self, epoch: u32) {
393 self.next_connection_id.store(0, SeqCst);
394 self.epoch.store(epoch, SeqCst);
395 }
396
397 pub fn teardown(&self) {
398 self.connections.write().clear();
399 }
400
401 /// Make a request and wait for a response.
402 pub fn request<T: RequestMessage>(
403 &self,
404 receiver_id: ConnectionId,
405 request: T,
406 ) -> impl Future<Output = Result<T::Response>> + use<T> {
407 self.request_internal(None, receiver_id, request)
408 .map_ok(|envelope| envelope.payload)
409 }
410
411 pub fn request_envelope<T: RequestMessage>(
412 &self,
413 receiver_id: ConnectionId,
414 request: T,
415 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
416 self.request_internal(None, receiver_id, request)
417 }
418
419 pub fn forward_request<T: RequestMessage>(
420 &self,
421 sender_id: ConnectionId,
422 receiver_id: ConnectionId,
423 request: T,
424 ) -> impl Future<Output = Result<T::Response>> {
425 let request_start_time = Instant::now();
426 let payload_type = T::NAME;
427 let elapsed_time = move || request_start_time.elapsed().as_millis();
428 tracing::info!(payload_type, "start forwarding request");
429 self.request_internal(Some(sender_id), receiver_id, request)
430 .map_ok(|envelope| envelope.payload)
431 .inspect_err(move |_| {
432 tracing::error!(
433 waiting_for_host_ms = elapsed_time(),
434 payload_type,
435 "error forwarding request"
436 )
437 })
438 .inspect_ok(move |_| {
439 tracing::info!(
440 waiting_for_host_ms = elapsed_time(),
441 payload_type,
442 "finished forwarding request"
443 )
444 })
445 }
446
447 fn request_internal<T: RequestMessage>(
448 &self,
449 original_sender_id: Option<ConnectionId>,
450 receiver_id: ConnectionId,
451 request: T,
452 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
453 let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
454 let response = self.request_dynamic(receiver_id, envelope, T::NAME);
455 async move {
456 let (response, received_at) = response.await?;
457 Ok(TypedEnvelope {
458 message_id: response.id,
459 sender_id: receiver_id.into(),
460 original_sender_id: response.original_sender_id,
461 payload: T::Response::from_envelope(response)
462 .context("received response of the wrong type")?,
463 received_at,
464 })
465 }
466 }
467
468 /// Make a request and wait for a response.
469 ///
470 /// The caller must make sure to deserialize the response into the request's
471 /// response type. This interface is only useful in trait objects, where
472 /// generics can't be used. If you have a concrete type, use `request`.
473 pub fn request_dynamic(
474 &self,
475 receiver_id: ConnectionId,
476 mut envelope: proto::Envelope,
477 type_name: &'static str,
478 ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> + use<> {
479 let (tx, rx) = oneshot::channel();
480 let send = self.connection_state(receiver_id).and_then(|connection| {
481 envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
482 connection
483 .response_channels
484 .lock()
485 .as_mut()
486 .context("connection was closed")?
487 .insert(envelope.id, tx);
488 connection
489 .outgoing_tx
490 .unbounded_send(Message::Envelope(envelope))
491 .context("connection was closed")?;
492 Ok(())
493 });
494 async move {
495 send?;
496 let (response, received_at, _barrier) = rx.await.context("connection was closed")?;
497 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
498 return Err(RpcError::from_proto(error, type_name));
499 }
500 Ok((response, received_at))
501 }
502 }
503
504 pub fn request_stream<T: RequestMessage>(
505 &self,
506 receiver_id: ConnectionId,
507 request: T,
508 ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
509 let (tx, rx) = mpsc::unbounded();
510 let send = self.connection_state(receiver_id).and_then(|connection| {
511 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
512 let stream_response_channels = connection.stream_response_channels.clone();
513 stream_response_channels
514 .lock()
515 .as_mut()
516 .context("connection was closed")?
517 .insert(message_id, tx);
518 connection
519 .outgoing_tx
520 .unbounded_send(Message::Envelope(
521 request.into_envelope(message_id, None, None),
522 ))
523 .context("connection was closed")?;
524 Ok((message_id, stream_response_channels))
525 });
526
527 async move {
528 let (message_id, stream_response_channels) = send?;
529 let stream_response_channels = Arc::downgrade(&stream_response_channels);
530
531 Ok(rx.filter_map(move |(response, _barrier)| {
532 let stream_response_channels = stream_response_channels.clone();
533 future::ready(match response {
534 Ok(response) => {
535 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
536 Some(Err(RpcError::from_proto(error, T::NAME)))
537 } else if let Some(proto::envelope::Payload::EndStream(_)) =
538 &response.payload
539 {
540 // Remove the transmitting end of the response channel to end the stream.
541 if let Some(channels) = stream_response_channels.upgrade() {
542 if let Some(channels) = channels.lock().as_mut() {
543 channels.remove(&message_id);
544 }
545 }
546 None
547 } else {
548 Some(
549 T::Response::from_envelope(response)
550 .context("received response of the wrong type"),
551 )
552 }
553 }
554 Err(error) => Some(Err(error)),
555 })
556 }))
557 }
558 }
559
560 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
561 let connection = self.connection_state(receiver_id)?;
562 let message_id = connection
563 .next_message_id
564 .fetch_add(1, atomic::Ordering::SeqCst);
565 connection.outgoing_tx.unbounded_send(Message::Envelope(
566 message.into_envelope(message_id, None, None),
567 ))?;
568 Ok(())
569 }
570
571 pub fn send_dynamic(&self, receiver_id: ConnectionId, message: proto::Envelope) -> Result<()> {
572 let connection = self.connection_state(receiver_id)?;
573 connection
574 .outgoing_tx
575 .unbounded_send(Message::Envelope(message))?;
576 Ok(())
577 }
578
579 pub fn forward_send<T: EnvelopedMessage>(
580 &self,
581 sender_id: ConnectionId,
582 receiver_id: ConnectionId,
583 message: T,
584 ) -> Result<()> {
585 let connection = self.connection_state(receiver_id)?;
586 let message_id = connection
587 .next_message_id
588 .fetch_add(1, atomic::Ordering::SeqCst);
589 connection
590 .outgoing_tx
591 .unbounded_send(Message::Envelope(message.into_envelope(
592 message_id,
593 None,
594 Some(sender_id.into()),
595 )))?;
596 Ok(())
597 }
598
599 pub fn respond<T: RequestMessage>(
600 &self,
601 receipt: Receipt<T>,
602 response: T::Response,
603 ) -> Result<()> {
604 let connection = self.connection_state(receipt.sender_id.into())?;
605 let message_id = connection
606 .next_message_id
607 .fetch_add(1, atomic::Ordering::SeqCst);
608 connection
609 .outgoing_tx
610 .unbounded_send(Message::Envelope(response.into_envelope(
611 message_id,
612 Some(receipt.message_id),
613 None,
614 )))?;
615 Ok(())
616 }
617
618 pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
619 let connection = self.connection_state(receipt.sender_id.into())?;
620 let message_id = connection
621 .next_message_id
622 .fetch_add(1, atomic::Ordering::SeqCst);
623
624 let message = proto::EndStream {};
625
626 connection
627 .outgoing_tx
628 .unbounded_send(Message::Envelope(message.into_envelope(
629 message_id,
630 Some(receipt.message_id),
631 None,
632 )))?;
633 Ok(())
634 }
635
636 pub fn respond_with_error<T: RequestMessage>(
637 &self,
638 receipt: Receipt<T>,
639 response: proto::Error,
640 ) -> Result<()> {
641 let connection = self.connection_state(receipt.sender_id.into())?;
642 let message_id = connection
643 .next_message_id
644 .fetch_add(1, atomic::Ordering::SeqCst);
645 connection
646 .outgoing_tx
647 .unbounded_send(Message::Envelope(response.into_envelope(
648 message_id,
649 Some(receipt.message_id),
650 None,
651 )))?;
652 Ok(())
653 }
654
655 pub fn respond_with_unhandled_message(
656 &self,
657 sender_id: ConnectionId,
658 request_message_id: u32,
659 message_type_name: &'static str,
660 ) -> Result<()> {
661 let connection = self.connection_state(sender_id)?;
662 let response = ErrorCode::Internal
663 .message(format!("message {} was not handled", message_type_name))
664 .to_proto();
665 let message_id = connection
666 .next_message_id
667 .fetch_add(1, atomic::Ordering::SeqCst);
668 connection
669 .outgoing_tx
670 .unbounded_send(Message::Envelope(response.into_envelope(
671 message_id,
672 Some(request_message_id),
673 None,
674 )))?;
675 Ok(())
676 }
677
678 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
679 let connections = self.connections.read();
680 let connection = connections
681 .get(&connection_id)
682 .with_context(|| format!("no such connection: {connection_id}"))?;
683 Ok(connection.clone())
684 }
685}
686
687impl Serialize for Peer {
688 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
689 where
690 S: serde::Serializer,
691 {
692 let mut state = serializer.serialize_struct("Peer", 2)?;
693 state.serialize_field("connections", &*self.connections.read())?;
694 state.end()
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    fn init_logger() {
        zlog::init_test();
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(
            server_conn
                .send(WebSocketMessage::Binary(vec![].into()))
                .await
                .is_err()
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}