1use super::{
2 Connection,
3 message_stream::{Message, MessageStream},
4 proto::{
5 self, AnyTypedEnvelope, EnvelopedMessage, PeerId, Receipt, RequestMessage, TypedEnvelope,
6 },
7};
8use anyhow::{Context as _, Result, anyhow};
9use collections::HashMap;
10use futures::{
11 FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
12 channel::{mpsc, oneshot},
13 stream::BoxStream,
14};
15use parking_lot::{Mutex, RwLock};
16use proto::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
17use serde::{Serialize, ser::SerializeStruct};
18use std::{
19 fmt, future,
20 future::Future,
21 sync::atomic::Ordering::SeqCst,
22 sync::{
23 Arc,
24 atomic::{self, AtomicU32},
25 },
26 time::Duration,
27 time::Instant,
28};
29
/// Identifies one connection managed by a [`Peer`].
///
/// `Peer::add_connection` fills `owner_id` with the peer's epoch at the time the
/// connection is added and `id` with the next value of the peer's connection counter.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
pub struct ConnectionId {
    /// Epoch of the owning peer when this connection was created.
    pub owner_id: u32,
    /// Per-peer connection counter value.
    pub id: u32,
}
35
36impl From<ConnectionId> for PeerId {
37 fn from(id: ConnectionId) -> Self {
38 PeerId {
39 owner_id: id.owner_id,
40 id: id.id,
41 }
42 }
43}
44
45impl From<PeerId> for ConnectionId {
46 fn from(peer_id: PeerId) -> Self {
47 Self {
48 owner_id: peer_id.owner_id,
49 id: peer_id.id,
50 }
51 }
52}
53
54impl fmt::Display for ConnectionId {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 write!(f, "{}/{}", self.owner_id, self.id)
57 }
58}
59
/// Sends and receives RPC messages over a set of connections, tracking the
/// requests that are waiting for responses on each of them.
pub struct Peer {
    /// Value written into the `owner_id` of every `ConnectionId` this peer creates.
    epoch: AtomicU32,
    /// State of every registered connection. Entries are inserted by `add_connection`
    /// and removed by `disconnect`, `teardown`, or when the connection's I/O future ends.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Source of the `id` half of newly created `ConnectionId`s.
    next_connection_id: AtomicU32,
}
65
66#[derive(Clone, Serialize)]
67pub struct ConnectionState {
68 #[serde(skip)]
69 outgoing_tx: mpsc::UnboundedSender<Message>,
70 next_message_id: Arc<AtomicU32>,
71 #[allow(clippy::type_complexity)]
72 #[serde(skip)]
73 response_channels: Arc<
74 Mutex<
75 Option<
76 HashMap<
77 u32,
78 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
79 >,
80 >,
81 >,
82 >,
83 #[allow(clippy::type_complexity)]
84 #[serde(skip)]
85 stream_response_channels: Arc<
86 Mutex<
87 Option<
88 HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
89 >,
90 >,
91 >,
92}
93
/// A `Ping` is written if nothing has been sent on a connection for this long.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound for a single socket write or for handing an incoming message to the consumer.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// A connection is closed if nothing at all is received from it for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
97
98impl Peer {
99 pub fn new(epoch: u32) -> Arc<Self> {
100 Arc::new(Self {
101 epoch: AtomicU32::new(epoch),
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub fn epoch(&self) -> u32 {
108 self.epoch.load(SeqCst)
109 }
110
    /// Registers `connection` with this peer and returns everything needed to drive it.
    ///
    /// The returned tuple contains:
    /// - the new [`ConnectionId`], built from the current epoch and the next connection counter;
    /// - an I/O future that must be polled for the connection to make progress. It writes queued
    ///   outgoing messages, writes a `Ping` whenever `KEEPALIVE_INTERVAL` passes without an
    ///   outgoing write, and reads incoming messages. It fails if a write or the hand-off of an
    ///   incoming message takes longer than `WRITE_TIMEOUT`, or if nothing is received for
    ///   `RECEIVE_TIMEOUT`. It returns `Ok(())` once the outgoing queue or the incoming receiver
    ///   is closed. However it ends, it drops all pending unary response senders, sends a
    ///   "connection closed" error to every pending stream, and removes the connection from
    ///   `self.connections`;
    /// - a stream of incoming non-response messages. Responses are routed to the request waiting
    ///   for them and are never yielded by this stream.
    ///
    /// `create_timer(d)` must return a future that completes after `d`; it backs the keepalive,
    /// receive-timeout and write-timeout timers.
    pub fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send + use<F, Fut, Out>,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 256;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId {
            owner_id: self.epoch.load(SeqCst),
            id: self.next_connection_id.fetch_add(1, SeqCst),
        };
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
            stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        // Handles moved into the I/O future; used by its cleanup guard below.
        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();

        let handle_io = async move {
            tracing::trace!(%connection_id, "handle io future: start");

            // Runs when this future finishes or is dropped: taking the maps makes later
            // registrations fail, drops unary response senders (their requesters then see
            // "connection was closed"), and explicitly errors every open stream.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                if let Some(channels) = stream_response_channels.lock().take() {
                    for channel in channels.values() {
                        let _ = channel.unbounded_send((
                            Err(anyhow!("connection closed")),
                            oneshot::channel().0,
                        ));
                    }
                }
                this.connections.write().remove(&connection_id);
                tracing::trace!(%connection_id, "handle io future: end");
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            loop {
                tracing::trace!(%connection_id, "outer loop iteration start");
                // A fresh read future per incoming message; it is kept alive across inner-loop
                // iterations so writes and pings don't cancel a partially completed read.
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    tracing::trace!(%connection_id, "inner loop iteration start");
                    // Biased: outgoing messages are checked first, then the keepalive timer,
                    // then incoming data, then the receive timeout.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                tracing::trace!(%connection_id, "outgoing rpc message: writing");
                                futures::select_biased! {
                                    result = writer.write(outgoing).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: done writing");
                                        result.context("failed to write RPC message")?;
                                        tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
                                        // Any outgoing write counts as activity, so postpone the next ping.
                                        keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                    }
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
                                        anyhow::bail!("timed out writing message");
                                    }
                                }
                            }
                            // Every sender (i.e. every `ConnectionState` clone) is gone.
                            None => {
                                tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
                                return Ok(())
                            },
                        },
                        _ = keepalive_timer => {
                            tracing::trace!(%connection_id, "keepalive interval: pinging");
                            futures::select_biased! {
                                result = writer.write(Message::Ping).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: done pinging");
                                    result.context("failed to send keepalive")?;
                                    tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                }
                                _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                    tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
                                    anyhow::bail!("timed out sending keepalive");
                                }
                            }
                        }
                        incoming = read_message => {
                            let incoming = incoming.context("error reading rpc message from socket")?;
                            tracing::trace!(%connection_id, "incoming rpc message: received");
                            tracing::trace!(%connection_id, "receive timeout: resetting");
                            // Any received message, including non-envelope ones, keeps the connection alive.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let (Message::Envelope(incoming), received_at) = incoming {
                                tracing::trace!(%connection_id, "incoming rpc message: processing");
                                // The incoming channel is bounded; give up if the consumer stays
                                // full for longer than WRITE_TIMEOUT.
                                futures::select_biased! {
                                    result = incoming_tx.send((incoming, received_at)).fuse() => match result {
                                        Ok(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: processed");
                                        }
                                        Err(_) => {
                                            tracing::trace!(%connection_id, "incoming rpc message: channel closed");
                                            return Ok(())
                                        }
                                    },
                                    _ = create_timer(WRITE_TIMEOUT).fuse() => {
                                        tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
                                        anyhow::bail!("timed out processing incoming message");
                                    }
                                }
                            }
                            // The read future completed; go back to the outer loop to start a new one.
                            break;
                        },
                        _ = receive_timeout => {
                            tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
                            anyhow::bail!("delay between messages too long");
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        let stream_response_channels = connection_state.stream_response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Incoming envelopes are classified here: responses are routed to waiting requesters
        // (and filtered out), everything else is converted to a typed envelope and yielded.
        let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
            let response_channels = response_channels.clone();
            let stream_response_channels = stream_response_channels.clone();
            async move {
                let message_id = incoming.id;
                tracing::trace!(?incoming, "incoming message future: start");
                let _end = util::defer(move || {
                    tracing::trace!(%connection_id, message_id, "incoming message future: end");
                });

                if let Some(responding_to) = incoming.responding_to {
                    tracing::trace!(
                        %connection_id,
                        message_id,
                        responding_to,
                        "incoming response: received"
                    );
                    // `?` on a `None` map (connection already shut down) drops the message.
                    // A unary request is answered at most once, so its sender is removed;
                    // a stream sender stays registered until the stream ends.
                    let response_channel =
                        response_channels.lock().as_mut()?.remove(&responding_to);
                    let stream_response_channel = stream_response_channels
                        .lock()
                        .as_ref()?
                        .get(&responding_to)
                        .cloned();

                    if let Some(tx) = response_channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
                            tracing::trace!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming response: request future dropped",
                            );
                        }

                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: waiting to resume requester"
                        );
                        // Hold this stream until the requester drops the barrier sender (or the
                        // request is dropped), so no later incoming message is yielded before
                        // the requester has resumed with its response.
                        let _ = requester_resumed.1.await;
                        tracing::trace!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming response: requester resumed"
                        );
                    } else if let Some(tx) = stream_response_channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
                            tracing::debug!(
                                %connection_id,
                                message_id,
                                responding_to = responding_to,
                                ?error,
                                "incoming stream response: request future dropped",
                            );
                        }

                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: waiting to resume requester"
                        );
                        // Same barrier as above, for streaming responses.
                        let _ = requester_resumed.1.await;
                        tracing::debug!(
                            %connection_id,
                            message_id,
                            responding_to,
                            "incoming stream response: requester resumed"
                        );
                    } else {
                        // No pending request with this id; decode only to log the payload type.
                        let message_type = proto::build_typed_envelope(
                            connection_id.into(),
                            received_at,
                            incoming,
                        )
                        .map(|p| p.payload_type_name());
                        tracing::warn!(
                            %connection_id,
                            message_id,
                            responding_to,
                            message_type,
                            "incoming response: unknown request"
                        );
                    }

                    // Responses are never yielded to the consumer of the incoming stream.
                    None
                } else {
                    tracing::trace!(%connection_id, message_id, "incoming message: received");
                    proto::build_typed_envelope(connection_id.into(), received_at, incoming)
                        .or_else(|| {
                            tracing::error!(
                                %connection_id,
                                message_id,
                                "unable to construct a typed envelope"
                            );
                            None
                        })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
370
371 #[cfg(any(test, feature = "test-support"))]
372 pub fn add_test_connection(
373 self: &Arc<Self>,
374 connection: Connection,
375 executor: gpui::BackgroundExecutor,
376 ) -> (
377 ConnectionId,
378 impl Future<Output = anyhow::Result<()>> + Send + use<>,
379 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
380 ) {
381 let executor = executor.clone();
382 self.add_connection(connection, move |duration| executor.timer(duration))
383 }
384
385 pub fn disconnect(&self, connection_id: ConnectionId) {
386 self.connections.write().remove(&connection_id);
387 }
388
389 #[cfg(any(test, feature = "test-support"))]
390 pub fn reset(&self, epoch: u32) {
391 self.next_connection_id.store(0, SeqCst);
392 self.epoch.store(epoch, SeqCst);
393 }
394
395 pub fn teardown(&self) {
396 self.connections.write().clear();
397 }
398
399 /// Make a request and wait for a response.
400 pub fn request<T: RequestMessage>(
401 &self,
402 receiver_id: ConnectionId,
403 request: T,
404 ) -> impl Future<Output = Result<T::Response>> + use<T> {
405 self.request_internal(None, receiver_id, request)
406 .map_ok(|envelope| envelope.payload)
407 }
408
409 pub fn request_envelope<T: RequestMessage>(
410 &self,
411 receiver_id: ConnectionId,
412 request: T,
413 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
414 self.request_internal(None, receiver_id, request)
415 }
416
417 pub fn forward_request<T: RequestMessage>(
418 &self,
419 sender_id: ConnectionId,
420 receiver_id: ConnectionId,
421 request: T,
422 ) -> impl Future<Output = Result<T::Response>> {
423 self.request_internal(Some(sender_id), receiver_id, request)
424 .map_ok(|envelope| envelope.payload)
425 }
426
427 fn request_internal<T: RequestMessage>(
428 &self,
429 original_sender_id: Option<ConnectionId>,
430 receiver_id: ConnectionId,
431 request: T,
432 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
433 let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
434 let response = self.request_dynamic(receiver_id, envelope, T::NAME);
435 async move {
436 let (response, received_at) = response.await?;
437 Ok(TypedEnvelope {
438 message_id: response.id,
439 sender_id: receiver_id.into(),
440 original_sender_id: response.original_sender_id,
441 payload: T::Response::from_envelope(response)
442 .context("received response of the wrong type")?,
443 received_at,
444 })
445 }
446 }
447
448 /// Make a request and wait for a response.
449 ///
450 /// The caller must make sure to deserialize the response into the request's
451 /// response type. This interface is only useful in trait objects, where
452 /// generics can't be used. If you have a concrete type, use `request`.
453 pub fn request_dynamic(
454 &self,
455 receiver_id: ConnectionId,
456 mut envelope: proto::Envelope,
457 type_name: &'static str,
458 ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> + use<> {
459 let (tx, rx) = oneshot::channel();
460 let send = self.connection_state(receiver_id).and_then(|connection| {
461 envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
462 connection
463 .response_channels
464 .lock()
465 .as_mut()
466 .context("connection was closed")?
467 .insert(envelope.id, tx);
468 connection
469 .outgoing_tx
470 .unbounded_send(Message::Envelope(envelope))
471 .context("connection was closed")?;
472 Ok(())
473 });
474 async move {
475 send?;
476 let (response, received_at, _barrier) = rx.await.context("connection was closed")?;
477 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
478 return Err(RpcError::from_proto(error, type_name));
479 }
480 Ok((response, received_at))
481 }
482 }
483
484 pub fn request_stream<T: RequestMessage>(
485 &self,
486 receiver_id: ConnectionId,
487 request: T,
488 ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
489 let (tx, rx) = mpsc::unbounded();
490 let send = self.connection_state(receiver_id).and_then(|connection| {
491 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
492 let stream_response_channels = connection.stream_response_channels.clone();
493 stream_response_channels
494 .lock()
495 .as_mut()
496 .context("connection was closed")?
497 .insert(message_id, tx);
498 connection
499 .outgoing_tx
500 .unbounded_send(Message::Envelope(
501 request.into_envelope(message_id, None, None),
502 ))
503 .context("connection was closed")?;
504 Ok((message_id, stream_response_channels))
505 });
506
507 async move {
508 let (message_id, stream_response_channels) = send?;
509 let stream_response_channels = Arc::downgrade(&stream_response_channels);
510
511 Ok(rx.filter_map(move |(response, _barrier)| {
512 let stream_response_channels = stream_response_channels.clone();
513 future::ready(match response {
514 Ok(response) => {
515 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
516 Some(Err(RpcError::from_proto(error, T::NAME)))
517 } else if let Some(proto::envelope::Payload::EndStream(_)) =
518 &response.payload
519 {
520 // Remove the transmitting end of the response channel to end the stream.
521 if let Some(channels) = stream_response_channels.upgrade()
522 && let Some(channels) = channels.lock().as_mut()
523 {
524 channels.remove(&message_id);
525 }
526 None
527 } else {
528 Some(
529 T::Response::from_envelope(response)
530 .context("received response of the wrong type"),
531 )
532 }
533 }
534 Err(error) => Some(Err(error)),
535 })
536 }))
537 }
538 }
539
540 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
541 let connection = self.connection_state(receiver_id)?;
542 let message_id = connection
543 .next_message_id
544 .fetch_add(1, atomic::Ordering::SeqCst);
545 connection.outgoing_tx.unbounded_send(Message::Envelope(
546 message.into_envelope(message_id, None, None),
547 ))?;
548 Ok(())
549 }
550
551 pub fn send_dynamic(&self, receiver_id: ConnectionId, message: proto::Envelope) -> Result<()> {
552 let connection = self.connection_state(receiver_id)?;
553 connection
554 .outgoing_tx
555 .unbounded_send(Message::Envelope(message))?;
556 Ok(())
557 }
558
559 pub fn forward_send<T: EnvelopedMessage>(
560 &self,
561 sender_id: ConnectionId,
562 receiver_id: ConnectionId,
563 message: T,
564 ) -> Result<()> {
565 let connection = self.connection_state(receiver_id)?;
566 let message_id = connection
567 .next_message_id
568 .fetch_add(1, atomic::Ordering::SeqCst);
569 connection
570 .outgoing_tx
571 .unbounded_send(Message::Envelope(message.into_envelope(
572 message_id,
573 None,
574 Some(sender_id.into()),
575 )))?;
576 Ok(())
577 }
578
579 pub fn respond<T: RequestMessage>(
580 &self,
581 receipt: Receipt<T>,
582 response: T::Response,
583 ) -> Result<()> {
584 let connection = self.connection_state(receipt.sender_id.into())?;
585 let message_id = connection
586 .next_message_id
587 .fetch_add(1, atomic::Ordering::SeqCst);
588 connection
589 .outgoing_tx
590 .unbounded_send(Message::Envelope(response.into_envelope(
591 message_id,
592 Some(receipt.message_id),
593 None,
594 )))?;
595 Ok(())
596 }
597
598 pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
599 let connection = self.connection_state(receipt.sender_id.into())?;
600 let message_id = connection
601 .next_message_id
602 .fetch_add(1, atomic::Ordering::SeqCst);
603
604 let message = proto::EndStream {};
605
606 connection
607 .outgoing_tx
608 .unbounded_send(Message::Envelope(message.into_envelope(
609 message_id,
610 Some(receipt.message_id),
611 None,
612 )))?;
613 Ok(())
614 }
615
616 pub fn respond_with_error<T: RequestMessage>(
617 &self,
618 receipt: Receipt<T>,
619 response: proto::Error,
620 ) -> Result<()> {
621 let connection = self.connection_state(receipt.sender_id.into())?;
622 let message_id = connection
623 .next_message_id
624 .fetch_add(1, atomic::Ordering::SeqCst);
625 connection
626 .outgoing_tx
627 .unbounded_send(Message::Envelope(response.into_envelope(
628 message_id,
629 Some(receipt.message_id),
630 None,
631 )))?;
632 Ok(())
633 }
634
635 pub fn respond_with_unhandled_message(
636 &self,
637 sender_id: ConnectionId,
638 request_message_id: u32,
639 message_type_name: &'static str,
640 ) -> Result<()> {
641 let connection = self.connection_state(sender_id)?;
642 let response = ErrorCode::Internal
643 .message(format!("message {} was not handled", message_type_name))
644 .to_proto();
645 let message_id = connection
646 .next_message_id
647 .fetch_add(1, atomic::Ordering::SeqCst);
648 connection
649 .outgoing_tx
650 .unbounded_send(Message::Envelope(response.into_envelope(
651 message_id,
652 Some(request_message_id),
653 None,
654 )))?;
655 Ok(())
656 }
657
658 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
659 let connections = self.connections.read();
660 let connection = connections
661 .get(&connection_id)
662 .with_context(|| format!("no such connection: {connection_id}"))?;
663 Ok(connection.clone())
664 }
665}
666
667impl Serialize for Peer {
668 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
669 where
670 S: serde::Serializer,
671 {
672 let mut state = serializer.serialize_struct("Peer", 2)?;
673 state.serialize_field("connections", &*self.connections.read())?;
674 state.end()
675 }
676}
677
678#[cfg(test)]
679mod tests {
680 use super::*;
681 use async_tungstenite::tungstenite::Message as WebSocketMessage;
682 use gpui::TestAppContext;
683
684 fn init_logger() {
685 zlog::init_test();
686 }
687
688 #[gpui::test(iterations = 50)]
689 async fn test_request_response(cx: &mut TestAppContext) {
690 init_logger();
691
692 let executor = cx.executor();
693
694 // create 2 clients connected to 1 server
695 let server = Peer::new(0);
696 let client1 = Peer::new(0);
697 let client2 = Peer::new(0);
698
699 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
700 Connection::in_memory(cx.executor());
701 let (client1_conn_id, io_task1, client1_incoming) =
702 client1.add_test_connection(client1_to_server_conn, cx.executor());
703 let (_, io_task2, server_incoming1) =
704 server.add_test_connection(server_to_client_1_conn, cx.executor());
705
706 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
707 Connection::in_memory(cx.executor());
708 let (client2_conn_id, io_task3, client2_incoming) =
709 client2.add_test_connection(client2_to_server_conn, cx.executor());
710 let (_, io_task4, server_incoming2) =
711 server.add_test_connection(server_to_client_2_conn, cx.executor());
712
713 executor.spawn(io_task1).detach();
714 executor.spawn(io_task2).detach();
715 executor.spawn(io_task3).detach();
716 executor.spawn(io_task4).detach();
717 executor
718 .spawn(handle_messages(server_incoming1, server.clone()))
719 .detach();
720 executor
721 .spawn(handle_messages(client1_incoming, client1.clone()))
722 .detach();
723 executor
724 .spawn(handle_messages(server_incoming2, server.clone()))
725 .detach();
726 executor
727 .spawn(handle_messages(client2_incoming, client2.clone()))
728 .detach();
729
730 assert_eq!(
731 client1
732 .request(client1_conn_id, proto::Ping {},)
733 .await
734 .unwrap(),
735 proto::Ack {}
736 );
737
738 assert_eq!(
739 client2
740 .request(client2_conn_id, proto::Ping {},)
741 .await
742 .unwrap(),
743 proto::Ack {}
744 );
745
746 assert_eq!(
747 client1
748 .request(client1_conn_id, proto::Test { id: 1 },)
749 .await
750 .unwrap(),
751 proto::Test { id: 1 }
752 );
753
754 assert_eq!(
755 client2
756 .request(client2_conn_id, proto::Test { id: 2 })
757 .await
758 .unwrap(),
759 proto::Test { id: 2 }
760 );
761
762 client1.disconnect(client1_conn_id);
763 client2.disconnect(client1_conn_id);
764
765 async fn handle_messages(
766 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
767 peer: Arc<Peer>,
768 ) -> Result<()> {
769 while let Some(envelope) = messages.next().await {
770 let envelope = envelope.into_any();
771 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
772 let receipt = envelope.receipt();
773 peer.respond(receipt, proto::Ack {})?
774 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
775 {
776 peer.respond(envelope.receipt(), envelope.payload.clone())?
777 } else {
778 panic!("unknown message type");
779 }
780 }
781
782 Ok(())
783 }
784 }
785
    /// Checks that two ordinary messages sent by the server before its response are recorded
    /// before the requester observes that response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server: wait for the Ping, send two unrelated Error messages, then answer the Ping.
        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client consumer: responses are routed by the incoming stream itself, so the
        // response can only reach the requester after both earlier messages are pulled here.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
883
    /// Checks that dropping one in-flight request does not stall the connection. The
    /// server's later messages and the response to a second request still arrive, in order.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server: receive both Pings, send two Error messages, then answer both Pings.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
995
    /// Checks that `Peer::disconnect` ends the connection's I/O future and incoming stream,
    /// after which the other end can no longer send on the connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        // Signals when the I/O future has finished (successfully or not).
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signals when the incoming stream has yielded or ended.
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(
            server_conn
                .send(WebSocketMessage::Binary(vec![].into()))
                .await
                .is_err()
        );
    }
1033
    /// Checks that a pending request fails with "connection was closed" when the remote
    /// end of the connection is dropped before it replies.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request has actually reached the "server" side before dropping it.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
1056}