1use crate::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
2
3use super::{
4 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, PeerId, RequestMessage},
5 Connection,
6};
7use anyhow::{anyhow, Context, Result};
8use collections::HashMap;
9use futures::{
10 channel::{mpsc, oneshot},
11 stream::BoxStream,
12 FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
13};
14use parking_lot::{Mutex, RwLock};
15use serde::{ser::SerializeStruct, Serialize};
16use std::{
17 fmt, future,
18 future::Future,
19 marker::PhantomData,
20 sync::atomic::Ordering::SeqCst,
21 sync::{
22 atomic::{self, AtomicU32},
23 Arc,
24 },
25 time::Duration,
26 time::Instant,
27};
28use tracing::instrument;
29
30#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
31pub struct ConnectionId {
32 pub owner_id: u32,
33 pub id: u32,
34}
35
36impl Into<PeerId> for ConnectionId {
37 fn into(self) -> PeerId {
38 PeerId {
39 owner_id: self.owner_id,
40 id: self.id,
41 }
42 }
43}
44
45impl From<PeerId> for ConnectionId {
46 fn from(peer_id: PeerId) -> Self {
47 Self {
48 owner_id: peer_id.owner_id,
49 id: peer_id.id,
50 }
51 }
52}
53
54impl fmt::Display for ConnectionId {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 write!(f, "{}/{}", self.owner_id, self.id)
57 }
58}
59
60pub struct Receipt<T> {
61 pub sender_id: ConnectionId,
62 pub message_id: u32,
63 payload_type: PhantomData<T>,
64}
65
66impl<T> Clone for Receipt<T> {
67 fn clone(&self) -> Self {
68 *self
69 }
70}
71
72impl<T> Copy for Receipt<T> {}
73
74#[derive(Clone, Debug)]
75pub struct TypedEnvelope<T> {
76 pub sender_id: ConnectionId,
77 pub original_sender_id: Option<PeerId>,
78 pub message_id: u32,
79 pub payload: T,
80 pub received_at: Instant,
81}
82
83impl<T> TypedEnvelope<T> {
84 pub fn original_sender_id(&self) -> Result<PeerId> {
85 self.original_sender_id
86 .ok_or_else(|| anyhow!("missing original_sender_id"))
87 }
88}
89
90impl<T: RequestMessage> TypedEnvelope<T> {
91 pub fn receipt(&self) -> Receipt<T> {
92 Receipt {
93 sender_id: self.sender_id,
94 message_id: self.message_id,
95 payload_type: PhantomData,
96 }
97 }
98}
99
100pub struct Peer {
101 epoch: AtomicU32,
102 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
103 next_connection_id: AtomicU32,
104}
105
106#[derive(Clone, Serialize)]
107pub struct ConnectionState {
108 #[serde(skip)]
109 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
110 next_message_id: Arc<AtomicU32>,
111 #[allow(clippy::type_complexity)]
112 #[serde(skip)]
113 response_channels: Arc<
114 Mutex<
115 Option<
116 HashMap<
117 u32,
118 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
119 >,
120 >,
121 >,
122 >,
123 #[allow(clippy::type_complexity)]
124 #[serde(skip)]
125 stream_response_channels: Arc<
126 Mutex<
127 Option<
128 HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
129 >,
130 >,
131 >,
132}
133
/// A ping is written whenever this long passes without any outgoing write.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound for one socket write, or for handing one incoming message to
/// the application, before the connection is treated as failed.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection fails if no message at all arrives within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
137
138impl Peer {
139 pub fn new(epoch: u32) -> Arc<Self> {
140 Arc::new(Self {
141 epoch: AtomicU32::new(epoch),
142 connections: Default::default(),
143 next_connection_id: Default::default(),
144 })
145 }
146
147 pub fn epoch(&self) -> u32 {
148 self.epoch.load(SeqCst)
149 }
150
151 #[instrument(skip_all)]
152 pub fn add_connection<F, Fut, Out>(
153 self: &Arc<Self>,
154 connection: Connection,
155 create_timer: F,
156 ) -> (
157 ConnectionId,
158 impl Future<Output = anyhow::Result<()>> + Send,
159 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
160 )
161 where
162 F: Send + Fn(Duration) -> Fut,
163 Fut: Send + Future<Output = Out>,
164 Out: Send,
165 {
166 // For outgoing messages, use an unbounded channel so that application code
167 // can always send messages without yielding. For incoming messages, use a
168 // bounded channel so that other peers will receive backpressure if they send
169 // messages faster than this peer can process them.
170 #[cfg(any(test, feature = "test-support"))]
171 const INCOMING_BUFFER_SIZE: usize = 1;
172 #[cfg(not(any(test, feature = "test-support")))]
173 const INCOMING_BUFFER_SIZE: usize = 256;
174 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
175 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
176
177 let connection_id = ConnectionId {
178 owner_id: self.epoch.load(SeqCst),
179 id: self.next_connection_id.fetch_add(1, SeqCst),
180 };
181 let connection_state = ConnectionState {
182 outgoing_tx,
183 next_message_id: Default::default(),
184 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
185 stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
186 };
187 let mut writer = MessageStream::new(connection.tx);
188 let mut reader = MessageStream::new(connection.rx);
189
190 let this = self.clone();
191 let response_channels = connection_state.response_channels.clone();
192 let stream_response_channels = connection_state.stream_response_channels.clone();
193
194 let handle_io = async move {
195 tracing::trace!(%connection_id, "handle io future: start");
196
197 let _end_connection = util::defer(|| {
198 response_channels.lock().take();
199 if let Some(channels) = stream_response_channels.lock().take() {
200 for channel in channels.values() {
201 let _ = channel.unbounded_send((
202 Err(anyhow!("connection closed")),
203 oneshot::channel().0,
204 ));
205 }
206 }
207 this.connections.write().remove(&connection_id);
208 tracing::trace!(%connection_id, "handle io future: end");
209 });
210
211 // Send messages on this frequency so the connection isn't closed.
212 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
213 futures::pin_mut!(keepalive_timer);
214
215 // Disconnect if we don't receive messages at least this frequently.
216 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
217 futures::pin_mut!(receive_timeout);
218
219 loop {
220 tracing::trace!(%connection_id, "outer loop iteration start");
221 let read_message = reader.read().fuse();
222 futures::pin_mut!(read_message);
223
224 loop {
225 tracing::trace!(%connection_id, "inner loop iteration start");
226 futures::select_biased! {
227 outgoing = outgoing_rx.next().fuse() => match outgoing {
228 Some(outgoing) => {
229 tracing::trace!(%connection_id, "outgoing rpc message: writing");
230 futures::select_biased! {
231 result = writer.write(outgoing).fuse() => {
232 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
233 result.context("failed to write RPC message")?;
234 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
235 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
236 }
237 _ = create_timer(WRITE_TIMEOUT).fuse() => {
238 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
239 Err(anyhow!("timed out writing message"))?;
240 }
241 }
242 }
243 None => {
244 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
245 return Ok(())
246 },
247 },
248 _ = keepalive_timer => {
249 tracing::trace!(%connection_id, "keepalive interval: pinging");
250 futures::select_biased! {
251 result = writer.write(proto::Message::Ping).fuse() => {
252 tracing::trace!(%connection_id, "keepalive interval: done pinging");
253 result.context("failed to send keepalive")?;
254 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
255 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
256 }
257 _ = create_timer(WRITE_TIMEOUT).fuse() => {
258 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
259 Err(anyhow!("timed out sending keepalive"))?;
260 }
261 }
262 }
263 incoming = read_message => {
264 let incoming = incoming.context("error reading rpc message from socket")?;
265 tracing::trace!(%connection_id, "incoming rpc message: received");
266 tracing::trace!(%connection_id, "receive timeout: resetting");
267 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
268 if let (proto::Message::Envelope(incoming), received_at) = incoming {
269 tracing::trace!(%connection_id, "incoming rpc message: processing");
270 futures::select_biased! {
271 result = incoming_tx.send((incoming, received_at)).fuse() => match result {
272 Ok(_) => {
273 tracing::trace!(%connection_id, "incoming rpc message: processed");
274 }
275 Err(_) => {
276 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
277 return Ok(())
278 }
279 },
280 _ = create_timer(WRITE_TIMEOUT).fuse() => {
281 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
282 Err(anyhow!("timed out processing incoming message"))?
283 }
284 }
285 }
286 break;
287 },
288 _ = receive_timeout => {
289 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
290 Err(anyhow!("delay between messages too long"))?
291 }
292 }
293 }
294 }
295 };
296
297 let response_channels = connection_state.response_channels.clone();
298 let stream_response_channels = connection_state.stream_response_channels.clone();
299 self.connections
300 .write()
301 .insert(connection_id, connection_state);
302
303 let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
304 let response_channels = response_channels.clone();
305 let stream_response_channels = stream_response_channels.clone();
306 async move {
307 let message_id = incoming.id;
308 tracing::trace!(?incoming, "incoming message future: start");
309 let _end = util::defer(move || {
310 tracing::trace!(%connection_id, message_id, "incoming message future: end");
311 });
312
313 if let Some(responding_to) = incoming.responding_to {
314 tracing::trace!(
315 %connection_id,
316 message_id,
317 responding_to,
318 "incoming response: received"
319 );
320 let response_channel =
321 response_channels.lock().as_mut()?.remove(&responding_to);
322 let stream_response_channel = stream_response_channels
323 .lock()
324 .as_ref()?
325 .get(&responding_to)
326 .cloned();
327
328 if let Some(tx) = response_channel {
329 let requester_resumed = oneshot::channel();
330 if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
331 tracing::trace!(
332 %connection_id,
333 message_id,
334 responding_to = responding_to,
335 ?error,
336 "incoming response: request future dropped",
337 );
338 }
339
340 tracing::trace!(
341 %connection_id,
342 message_id,
343 responding_to,
344 "incoming response: waiting to resume requester"
345 );
346 let _ = requester_resumed.1.await;
347 tracing::trace!(
348 %connection_id,
349 message_id,
350 responding_to,
351 "incoming response: requester resumed"
352 );
353 } else if let Some(tx) = stream_response_channel {
354 let requester_resumed = oneshot::channel();
355 if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
356 tracing::debug!(
357 %connection_id,
358 message_id,
359 responding_to = responding_to,
360 ?error,
361 "incoming stream response: request future dropped",
362 );
363 }
364
365 tracing::debug!(
366 %connection_id,
367 message_id,
368 responding_to,
369 "incoming stream response: waiting to resume requester"
370 );
371 let _ = requester_resumed.1.await;
372 tracing::debug!(
373 %connection_id,
374 message_id,
375 responding_to,
376 "incoming stream response: requester resumed"
377 );
378 } else {
379 let message_type =
380 proto::build_typed_envelope(connection_id, received_at, incoming)
381 .map(|p| p.payload_type_name());
382 tracing::warn!(
383 %connection_id,
384 message_id,
385 responding_to,
386 message_type,
387 "incoming response: unknown request"
388 );
389 }
390
391 None
392 } else {
393 tracing::trace!(%connection_id, message_id, "incoming message: received");
394 proto::build_typed_envelope(connection_id, received_at, incoming).or_else(
395 || {
396 tracing::error!(
397 %connection_id,
398 message_id,
399 "unable to construct a typed envelope"
400 );
401 None
402 },
403 )
404 }
405 }
406 });
407 (connection_id, handle_io, incoming_rx.boxed())
408 }
409
410 #[cfg(any(test, feature = "test-support"))]
411 pub fn add_test_connection(
412 self: &Arc<Self>,
413 connection: Connection,
414 executor: gpui::BackgroundExecutor,
415 ) -> (
416 ConnectionId,
417 impl Future<Output = anyhow::Result<()>> + Send,
418 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
419 ) {
420 let executor = executor.clone();
421 self.add_connection(connection, move |duration| executor.timer(duration))
422 }
423
424 pub fn disconnect(&self, connection_id: ConnectionId) {
425 self.connections.write().remove(&connection_id);
426 }
427
428 #[cfg(any(test, feature = "test-support"))]
429 pub fn reset(&self, epoch: u32) {
430 self.next_connection_id.store(0, SeqCst);
431 self.epoch.store(epoch, SeqCst);
432 }
433
434 pub fn teardown(&self) {
435 self.connections.write().clear();
436 }
437
438 /// Make a request and wait for a response.
439 pub fn request<T: RequestMessage>(
440 &self,
441 receiver_id: ConnectionId,
442 request: T,
443 ) -> impl Future<Output = Result<T::Response>> {
444 self.request_internal(None, receiver_id, request)
445 .map_ok(|envelope| envelope.payload)
446 }
447
448 pub fn request_envelope<T: RequestMessage>(
449 &self,
450 receiver_id: ConnectionId,
451 request: T,
452 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
453 self.request_internal(None, receiver_id, request)
454 }
455
456 pub fn forward_request<T: RequestMessage>(
457 &self,
458 sender_id: ConnectionId,
459 receiver_id: ConnectionId,
460 request: T,
461 ) -> impl Future<Output = Result<T::Response>> {
462 self.request_internal(Some(sender_id), receiver_id, request)
463 .map_ok(|envelope| envelope.payload)
464 }
465
466 fn request_internal<T: RequestMessage>(
467 &self,
468 original_sender_id: Option<ConnectionId>,
469 receiver_id: ConnectionId,
470 request: T,
471 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> {
472 let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
473 let response = self.request_dynamic(receiver_id, envelope, T::NAME);
474 async move {
475 let (response, received_at) = response.await?;
476 Ok(TypedEnvelope {
477 message_id: response.id,
478 sender_id: receiver_id,
479 original_sender_id: response.original_sender_id,
480 payload: T::Response::from_envelope(response)
481 .ok_or_else(|| anyhow!("received response of the wrong type"))?,
482 received_at,
483 })
484 }
485 }
486
487 /// Make a request and wait for a response.
488 ///
489 /// The caller must make sure to deserialize the response into the request's
490 /// response type. This interface is only useful in trait objects, where
491 /// generics can't be used. If you have a concrete type, use `request`.
492 pub fn request_dynamic(
493 &self,
494 receiver_id: ConnectionId,
495 mut envelope: proto::Envelope,
496 type_name: &'static str,
497 ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> {
498 let (tx, rx) = oneshot::channel();
499 let send = self.connection_state(receiver_id).and_then(|connection| {
500 envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
501 connection
502 .response_channels
503 .lock()
504 .as_mut()
505 .ok_or_else(|| anyhow!("connection was closed"))?
506 .insert(envelope.id, tx);
507 connection
508 .outgoing_tx
509 .unbounded_send(proto::Message::Envelope(envelope))
510 .map_err(|_| anyhow!("connection was closed"))?;
511 Ok(())
512 });
513 async move {
514 send?;
515 let (response, received_at, _barrier) =
516 rx.await.map_err(|_| anyhow!("connection was closed"))?;
517 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
518 return Err(RpcError::from_proto(&error, type_name));
519 }
520 Ok((response, received_at))
521 }
522 }
523
524 pub fn request_stream<T: RequestMessage>(
525 &self,
526 receiver_id: ConnectionId,
527 request: T,
528 ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
529 let (tx, rx) = mpsc::unbounded();
530 let send = self.connection_state(receiver_id).and_then(|connection| {
531 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
532 let stream_response_channels = connection.stream_response_channels.clone();
533 stream_response_channels
534 .lock()
535 .as_mut()
536 .ok_or_else(|| anyhow!("connection was closed"))?
537 .insert(message_id, tx);
538 connection
539 .outgoing_tx
540 .unbounded_send(proto::Message::Envelope(
541 request.into_envelope(message_id, None, None),
542 ))
543 .map_err(|_| anyhow!("connection was closed"))?;
544 Ok((message_id, stream_response_channels))
545 });
546
547 async move {
548 let (message_id, stream_response_channels) = send?;
549 let stream_response_channels = Arc::downgrade(&stream_response_channels);
550
551 Ok(rx.filter_map(move |(response, _barrier)| {
552 let stream_response_channels = stream_response_channels.clone();
553 future::ready(match response {
554 Ok(response) => {
555 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
556 Some(Err(anyhow!(
557 "RPC request {} failed - {}",
558 T::NAME,
559 error.message
560 )))
561 } else if let Some(proto::envelope::Payload::EndStream(_)) =
562 &response.payload
563 {
564 // Remove the transmitting end of the response channel to end the stream.
565 if let Some(channels) = stream_response_channels.upgrade() {
566 if let Some(channels) = channels.lock().as_mut() {
567 channels.remove(&message_id);
568 }
569 }
570 None
571 } else {
572 Some(
573 T::Response::from_envelope(response)
574 .ok_or_else(|| anyhow!("received response of the wrong type")),
575 )
576 }
577 }
578 Err(error) => Some(Err(error)),
579 })
580 }))
581 }
582 }
583
584 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
585 let connection = self.connection_state(receiver_id)?;
586 let message_id = connection
587 .next_message_id
588 .fetch_add(1, atomic::Ordering::SeqCst);
589 connection
590 .outgoing_tx
591 .unbounded_send(proto::Message::Envelope(
592 message.into_envelope(message_id, None, None),
593 ))?;
594 Ok(())
595 }
596
597 pub fn forward_send<T: EnvelopedMessage>(
598 &self,
599 sender_id: ConnectionId,
600 receiver_id: ConnectionId,
601 message: T,
602 ) -> Result<()> {
603 let connection = self.connection_state(receiver_id)?;
604 let message_id = connection
605 .next_message_id
606 .fetch_add(1, atomic::Ordering::SeqCst);
607 connection
608 .outgoing_tx
609 .unbounded_send(proto::Message::Envelope(message.into_envelope(
610 message_id,
611 None,
612 Some(sender_id.into()),
613 )))?;
614 Ok(())
615 }
616
617 pub fn respond<T: RequestMessage>(
618 &self,
619 receipt: Receipt<T>,
620 response: T::Response,
621 ) -> Result<()> {
622 let connection = self.connection_state(receipt.sender_id)?;
623 let message_id = connection
624 .next_message_id
625 .fetch_add(1, atomic::Ordering::SeqCst);
626 connection
627 .outgoing_tx
628 .unbounded_send(proto::Message::Envelope(response.into_envelope(
629 message_id,
630 Some(receipt.message_id),
631 None,
632 )))?;
633 Ok(())
634 }
635
636 pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
637 let connection = self.connection_state(receipt.sender_id)?;
638 let message_id = connection
639 .next_message_id
640 .fetch_add(1, atomic::Ordering::SeqCst);
641
642 let message = proto::EndStream {};
643
644 connection
645 .outgoing_tx
646 .unbounded_send(proto::Message::Envelope(message.into_envelope(
647 message_id,
648 Some(receipt.message_id),
649 None,
650 )))?;
651 Ok(())
652 }
653
654 pub fn respond_with_error<T: RequestMessage>(
655 &self,
656 receipt: Receipt<T>,
657 response: proto::Error,
658 ) -> Result<()> {
659 let connection = self.connection_state(receipt.sender_id)?;
660 let message_id = connection
661 .next_message_id
662 .fetch_add(1, atomic::Ordering::SeqCst);
663 connection
664 .outgoing_tx
665 .unbounded_send(proto::Message::Envelope(response.into_envelope(
666 message_id,
667 Some(receipt.message_id),
668 None,
669 )))?;
670 Ok(())
671 }
672
673 pub fn respond_with_unhandled_message(
674 &self,
675 envelope: Box<dyn AnyTypedEnvelope>,
676 ) -> Result<()> {
677 let connection = self.connection_state(envelope.sender_id())?;
678 let response = ErrorCode::Internal
679 .message(format!(
680 "message {} was not handled",
681 envelope.payload_type_name()
682 ))
683 .to_proto();
684 let message_id = connection
685 .next_message_id
686 .fetch_add(1, atomic::Ordering::SeqCst);
687 connection
688 .outgoing_tx
689 .unbounded_send(proto::Message::Envelope(response.into_envelope(
690 message_id,
691 Some(envelope.message_id()),
692 None,
693 )))?;
694 Ok(())
695 }
696
697 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
698 let connections = self.connections.read();
699 let connection = connections
700 .get(&connection_id)
701 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
702 Ok(connection.clone())
703 }
704}
705
706impl Serialize for Peer {
707 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
708 where
709 S: serde::Serializer,
710 {
711 let mut state = serializer.serialize_struct("Peer", 2)?;
712 state.serialize_field("connections", &*self.connections.read())?;
713 state.end()
714 }
715}
716
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Installs `env_logger` when `RUST_LOG` is set.
    ///
    /// Uses `try_init` because `#[gpui::test(iterations = ..)]` runs a test body
    /// repeatedly. `env_logger::init` panics once a logger is already installed.
    fn init_logger() {
        if std::env::var("RUST_LOG").is_ok() {
            let _ = env_logger::try_init();
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client disconnects its own connection.
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}