1use super::{
2 Connection,
3 message_stream::{Message, MessageStream},
4 proto::{
5 self, AnyTypedEnvelope, EnvelopedMessage, PeerId, Receipt, RequestMessage, TypedEnvelope,
6 },
7};
8use anyhow::{Context as _, Result, anyhow};
9use collections::HashMap;
10use futures::{
11 FutureExt, SinkExt, Stream, StreamExt, TryFutureExt,
12 channel::{mpsc, oneshot},
13 stream::BoxStream,
14};
15use parking_lot::{Mutex, RwLock};
16use proto::{ErrorCode, ErrorCodeExt, ErrorExt, RpcError};
17use serde::{Serialize, ser::SerializeStruct};
18use std::{
19 fmt, future,
20 future::Future,
21 sync::atomic::Ordering::SeqCst,
22 sync::{
23 Arc,
24 atomic::{self, AtomicU32},
25 },
26 time::Duration,
27 time::Instant,
28};
29
30#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)]
31pub struct ConnectionId {
32 pub owner_id: u32,
33 pub id: u32,
34}
35
36impl From<ConnectionId> for PeerId {
37 fn from(id: ConnectionId) -> Self {
38 PeerId {
39 owner_id: id.owner_id,
40 id: id.id,
41 }
42 }
43}
44
45impl From<PeerId> for ConnectionId {
46 fn from(peer_id: PeerId) -> Self {
47 Self {
48 owner_id: peer_id.owner_id,
49 id: peer_id.id,
50 }
51 }
52}
53
54impl fmt::Display for ConnectionId {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 write!(f, "{}/{}", self.owner_id, self.id)
57 }
58}
59
60pub struct Peer {
61 epoch: AtomicU32,
62 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
63 next_connection_id: AtomicU32,
64}
65
66#[derive(Clone, Serialize)]
67pub struct ConnectionState {
68 #[serde(skip)]
69 outgoing_tx: mpsc::UnboundedSender<Message>,
70 next_message_id: Arc<AtomicU32>,
71 #[allow(clippy::type_complexity)]
72 #[serde(skip)]
73 response_channels: Arc<
74 Mutex<
75 Option<
76 HashMap<
77 u32,
78 oneshot::Sender<(proto::Envelope, std::time::Instant, oneshot::Sender<()>)>,
79 >,
80 >,
81 >,
82 >,
83 #[allow(clippy::type_complexity)]
84 #[serde(skip)]
85 stream_response_channels: Arc<
86 Mutex<
87 Option<
88 HashMap<u32, mpsc::UnboundedSender<(Result<proto::Envelope>, oneshot::Sender<()>)>>,
89 >,
90 >,
91 >,
92}
93
/// Idle time after the last outgoing write before a ping is sent, so the
/// connection isn't closed for inactivity.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
/// Upper bound on a single socket write, or on handing one incoming message to
/// the incoming channel, before the connection is failed.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);
/// The connection is failed if no message at all arrives within this window.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(10_000);
97
98impl Peer {
99 pub fn new(epoch: u32) -> Arc<Self> {
100 Arc::new(Self {
101 epoch: AtomicU32::new(epoch),
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub fn epoch(&self) -> u32 {
108 self.epoch.load(SeqCst)
109 }
110
111 pub fn add_connection<F, Fut, Out>(
112 self: &Arc<Self>,
113 connection: Connection,
114 create_timer: F,
115 ) -> (
116 ConnectionId,
117 impl Future<Output = anyhow::Result<()>> + Send + use<F, Fut, Out>,
118 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
119 )
120 where
121 F: Send + Fn(Duration) -> Fut,
122 Fut: Send + Future<Output = Out>,
123 Out: Send,
124 {
125 // For outgoing messages, use an unbounded channel so that application code
126 // can always send messages without yielding. For incoming messages, use a
127 // bounded channel so that other peers will receive backpressure if they send
128 // messages faster than this peer can process them.
129 #[cfg(any(test, feature = "test-support"))]
130 const INCOMING_BUFFER_SIZE: usize = 1;
131 #[cfg(not(any(test, feature = "test-support")))]
132 const INCOMING_BUFFER_SIZE: usize = 256;
133 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
134 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
135
136 let connection_id = ConnectionId {
137 owner_id: self.epoch.load(SeqCst),
138 id: self.next_connection_id.fetch_add(1, SeqCst),
139 };
140 let connection_state = ConnectionState {
141 outgoing_tx,
142 next_message_id: Default::default(),
143 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
144 stream_response_channels: Arc::new(Mutex::new(Some(Default::default()))),
145 };
146 let mut writer = MessageStream::new(connection.tx);
147 let mut reader = MessageStream::new(connection.rx);
148
149 let this = self.clone();
150 let response_channels = connection_state.response_channels.clone();
151 let stream_response_channels = connection_state.stream_response_channels.clone();
152
153 let handle_io = async move {
154 tracing::trace!(%connection_id, "handle io future: start");
155
156 let _end_connection = util::defer(|| {
157 response_channels.lock().take();
158 if let Some(channels) = stream_response_channels.lock().take() {
159 for channel in channels.values() {
160 let _ = channel.unbounded_send((
161 Err(anyhow!("connection closed")),
162 oneshot::channel().0,
163 ));
164 }
165 }
166 this.connections.write().remove(&connection_id);
167 tracing::trace!(%connection_id, "handle io future: end");
168 });
169
170 // Send messages on this frequency so the connection isn't closed.
171 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
172 futures::pin_mut!(keepalive_timer);
173
174 // Disconnect if we don't receive messages at least this frequently.
175 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
176 futures::pin_mut!(receive_timeout);
177
178 loop {
179 tracing::trace!(%connection_id, "outer loop iteration start");
180 let read_message = reader.read().fuse();
181 futures::pin_mut!(read_message);
182
183 loop {
184 tracing::trace!(%connection_id, "inner loop iteration start");
185 futures::select_biased! {
186 outgoing = outgoing_rx.next().fuse() => match outgoing {
187 Some(outgoing) => {
188 tracing::trace!(%connection_id, "outgoing rpc message: writing");
189 futures::select_biased! {
190 result = writer.write(outgoing).fuse() => {
191 tracing::trace!(%connection_id, "outgoing rpc message: done writing");
192 result.context("failed to write RPC message")?;
193 tracing::trace!(%connection_id, "keepalive interval: resetting after sending message");
194 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
195 }
196 _ = create_timer(WRITE_TIMEOUT).fuse() => {
197 tracing::trace!(%connection_id, "outgoing rpc message: writing timed out");
198 anyhow::bail!("timed out writing message");
199 }
200 }
201 }
202 None => {
203 tracing::trace!(%connection_id, "outgoing rpc message: channel closed");
204 return Ok(())
205 },
206 },
207 _ = keepalive_timer => {
208 tracing::trace!(%connection_id, "keepalive interval: pinging");
209 futures::select_biased! {
210 result = writer.write(Message::Ping).fuse() => {
211 tracing::trace!(%connection_id, "keepalive interval: done pinging");
212 result.context("failed to send keepalive")?;
213 tracing::trace!(%connection_id, "keepalive interval: resetting after pinging");
214 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
215 }
216 _ = create_timer(WRITE_TIMEOUT).fuse() => {
217 tracing::trace!(%connection_id, "keepalive interval: pinging timed out");
218 anyhow::bail!("timed out sending keepalive");
219 }
220 }
221 }
222 incoming = read_message => {
223 let incoming = incoming.context("error reading rpc message from socket")?;
224 tracing::trace!(%connection_id, "incoming rpc message: received");
225 tracing::trace!(%connection_id, "receive timeout: resetting");
226 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
227 if let (Message::Envelope(incoming), received_at) = incoming {
228 tracing::trace!(%connection_id, "incoming rpc message: processing");
229 futures::select_biased! {
230 result = incoming_tx.send((incoming, received_at)).fuse() => match result {
231 Ok(_) => {
232 tracing::trace!(%connection_id, "incoming rpc message: processed");
233 }
234 Err(_) => {
235 tracing::trace!(%connection_id, "incoming rpc message: channel closed");
236 return Ok(())
237 }
238 },
239 _ = create_timer(WRITE_TIMEOUT).fuse() => {
240 tracing::trace!(%connection_id, "incoming rpc message: processing timed out");
241 anyhow::bail!("timed out processing incoming message");
242 }
243 }
244 }
245 break;
246 },
247 _ = receive_timeout => {
248 tracing::trace!(%connection_id, "receive timeout: delay between messages too long");
249 anyhow::bail!("delay between messages too long");
250 }
251 }
252 }
253 }
254 };
255
256 let response_channels = connection_state.response_channels.clone();
257 let stream_response_channels = connection_state.stream_response_channels.clone();
258 self.connections
259 .write()
260 .insert(connection_id, connection_state);
261
262 let incoming_rx = incoming_rx.filter_map(move |(incoming, received_at)| {
263 let response_channels = response_channels.clone();
264 let stream_response_channels = stream_response_channels.clone();
265 async move {
266 let message_id = incoming.id;
267 tracing::trace!(?incoming, "incoming message future: start");
268 let _end = util::defer(move || {
269 tracing::trace!(%connection_id, message_id, "incoming message future: end");
270 });
271
272 if let Some(responding_to) = incoming.responding_to {
273 tracing::trace!(
274 %connection_id,
275 message_id,
276 responding_to,
277 "incoming response: received"
278 );
279 let response_channel =
280 response_channels.lock().as_mut()?.remove(&responding_to);
281 let stream_response_channel = stream_response_channels
282 .lock()
283 .as_ref()?
284 .get(&responding_to)
285 .cloned();
286
287 if let Some(tx) = response_channel {
288 let requester_resumed = oneshot::channel();
289 if let Err(error) = tx.send((incoming, received_at, requester_resumed.0)) {
290 tracing::trace!(
291 %connection_id,
292 message_id,
293 responding_to = responding_to,
294 ?error,
295 "incoming response: request future dropped",
296 );
297 }
298
299 tracing::trace!(
300 %connection_id,
301 message_id,
302 responding_to,
303 "incoming response: waiting to resume requester"
304 );
305 let _ = requester_resumed.1.await;
306 tracing::trace!(
307 %connection_id,
308 message_id,
309 responding_to,
310 "incoming response: requester resumed"
311 );
312 } else if let Some(tx) = stream_response_channel {
313 let requester_resumed = oneshot::channel();
314 if let Err(error) = tx.unbounded_send((Ok(incoming), requester_resumed.0)) {
315 tracing::debug!(
316 %connection_id,
317 message_id,
318 responding_to = responding_to,
319 ?error,
320 "incoming stream response: request future dropped",
321 );
322 }
323
324 tracing::debug!(
325 %connection_id,
326 message_id,
327 responding_to,
328 "incoming stream response: waiting to resume requester"
329 );
330 let _ = requester_resumed.1.await;
331 tracing::debug!(
332 %connection_id,
333 message_id,
334 responding_to,
335 "incoming stream response: requester resumed"
336 );
337 } else {
338 let message_type = proto::build_typed_envelope(
339 connection_id.into(),
340 received_at,
341 incoming,
342 )
343 .map(|p| p.payload_type_name());
344 tracing::warn!(
345 %connection_id,
346 message_id,
347 responding_to,
348 message_type,
349 "incoming response: unknown request"
350 );
351 }
352
353 None
354 } else {
355 tracing::trace!(%connection_id, message_id, "incoming message: received");
356 proto::build_typed_envelope(connection_id.into(), received_at, incoming)
357 .or_else(|| {
358 tracing::error!(
359 %connection_id,
360 message_id,
361 "unable to construct a typed envelope"
362 );
363 None
364 })
365 }
366 }
367 });
368 (connection_id, handle_io, incoming_rx.boxed())
369 }
370
371 #[cfg(any(test, feature = "test-support"))]
372 pub fn add_test_connection(
373 self: &Arc<Self>,
374 connection: Connection,
375 executor: gpui::BackgroundExecutor,
376 ) -> (
377 ConnectionId,
378 impl Future<Output = anyhow::Result<()>> + Send + use<>,
379 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
380 ) {
381 self.add_connection(connection, move |duration| executor.timer(duration))
382 }
383
384 pub fn disconnect(&self, connection_id: ConnectionId) {
385 self.connections.write().remove(&connection_id);
386 }
387
388 #[cfg(any(test, feature = "test-support"))]
389 pub fn reset(&self, epoch: u32) {
390 self.next_connection_id.store(0, SeqCst);
391 self.epoch.store(epoch, SeqCst);
392 }
393
394 pub fn teardown(&self) {
395 self.connections.write().clear();
396 }
397
398 /// Make a request and wait for a response.
399 pub fn request<T: RequestMessage>(
400 &self,
401 receiver_id: ConnectionId,
402 request: T,
403 ) -> impl Future<Output = Result<T::Response>> + use<T> {
404 self.request_internal(None, receiver_id, request)
405 .map_ok(|envelope| envelope.payload)
406 }
407
408 pub fn request_envelope<T: RequestMessage>(
409 &self,
410 receiver_id: ConnectionId,
411 request: T,
412 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
413 self.request_internal(None, receiver_id, request)
414 }
415
416 pub fn forward_request<T: RequestMessage>(
417 &self,
418 sender_id: ConnectionId,
419 receiver_id: ConnectionId,
420 request: T,
421 ) -> impl Future<Output = Result<T::Response>> {
422 self.request_internal(Some(sender_id), receiver_id, request)
423 .map_ok(|envelope| envelope.payload)
424 }
425
426 fn request_internal<T: RequestMessage>(
427 &self,
428 original_sender_id: Option<ConnectionId>,
429 receiver_id: ConnectionId,
430 request: T,
431 ) -> impl Future<Output = Result<TypedEnvelope<T::Response>>> + use<T> {
432 let envelope = request.into_envelope(0, None, original_sender_id.map(Into::into));
433 let response = self.request_dynamic(receiver_id, envelope, T::NAME);
434 async move {
435 let (response, received_at) = response.await?;
436 Ok(TypedEnvelope {
437 message_id: response.id,
438 sender_id: receiver_id.into(),
439 original_sender_id: response.original_sender_id,
440 payload: T::Response::from_envelope(response)
441 .context("received response of the wrong type")?,
442 received_at,
443 })
444 }
445 }
446
447 /// Make a request and wait for a response.
448 ///
449 /// The caller must make sure to deserialize the response into the request's
450 /// response type. This interface is only useful in trait objects, where
451 /// generics can't be used. If you have a concrete type, use `request`.
452 pub fn request_dynamic(
453 &self,
454 receiver_id: ConnectionId,
455 mut envelope: proto::Envelope,
456 type_name: &'static str,
457 ) -> impl Future<Output = Result<(proto::Envelope, Instant)>> + use<> {
458 let (tx, rx) = oneshot::channel();
459 let send = self.connection_state(receiver_id).and_then(|connection| {
460 envelope.id = connection.next_message_id.fetch_add(1, SeqCst);
461 connection
462 .response_channels
463 .lock()
464 .as_mut()
465 .context("connection was closed")?
466 .insert(envelope.id, tx);
467 connection
468 .outgoing_tx
469 .unbounded_send(Message::Envelope(envelope))
470 .context("connection was closed")?;
471 Ok(())
472 });
473 async move {
474 send?;
475 let (response, received_at, _barrier) = rx.await.context("connection was closed")?;
476 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
477 return Err(RpcError::from_proto(error, type_name));
478 }
479 Ok((response, received_at))
480 }
481 }
482
483 pub fn request_stream<T: RequestMessage>(
484 &self,
485 receiver_id: ConnectionId,
486 request: T,
487 ) -> impl Future<Output = Result<impl Unpin + Stream<Item = Result<T::Response>>>> {
488 let (tx, rx) = mpsc::unbounded();
489 let send = self.connection_state(receiver_id).and_then(|connection| {
490 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
491 let stream_response_channels = connection.stream_response_channels.clone();
492 stream_response_channels
493 .lock()
494 .as_mut()
495 .context("connection was closed")?
496 .insert(message_id, tx);
497 connection
498 .outgoing_tx
499 .unbounded_send(Message::Envelope(
500 request.into_envelope(message_id, None, None),
501 ))
502 .context("connection was closed")?;
503 Ok((message_id, stream_response_channels))
504 });
505
506 async move {
507 let (message_id, stream_response_channels) = send?;
508 let stream_response_channels = Arc::downgrade(&stream_response_channels);
509
510 Ok(rx.filter_map(move |(response, _barrier)| {
511 let stream_response_channels = stream_response_channels.clone();
512 future::ready(match response {
513 Ok(response) => {
514 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
515 Some(Err(RpcError::from_proto(error, T::NAME)))
516 } else if let Some(proto::envelope::Payload::EndStream(_)) =
517 &response.payload
518 {
519 // Remove the transmitting end of the response channel to end the stream.
520 if let Some(channels) = stream_response_channels.upgrade()
521 && let Some(channels) = channels.lock().as_mut()
522 {
523 channels.remove(&message_id);
524 }
525 None
526 } else {
527 Some(
528 T::Response::from_envelope(response)
529 .context("received response of the wrong type"),
530 )
531 }
532 }
533 Err(error) => Some(Err(error)),
534 })
535 }))
536 }
537 }
538
539 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
540 let connection = self.connection_state(receiver_id)?;
541 let message_id = connection
542 .next_message_id
543 .fetch_add(1, atomic::Ordering::SeqCst);
544 connection.outgoing_tx.unbounded_send(Message::Envelope(
545 message.into_envelope(message_id, None, None),
546 ))?;
547 Ok(())
548 }
549
550 pub fn send_dynamic(&self, receiver_id: ConnectionId, message: proto::Envelope) -> Result<()> {
551 let connection = self.connection_state(receiver_id)?;
552 connection
553 .outgoing_tx
554 .unbounded_send(Message::Envelope(message))?;
555 Ok(())
556 }
557
558 pub fn forward_send<T: EnvelopedMessage>(
559 &self,
560 sender_id: ConnectionId,
561 receiver_id: ConnectionId,
562 message: T,
563 ) -> Result<()> {
564 let connection = self.connection_state(receiver_id)?;
565 let message_id = connection
566 .next_message_id
567 .fetch_add(1, atomic::Ordering::SeqCst);
568 connection
569 .outgoing_tx
570 .unbounded_send(Message::Envelope(message.into_envelope(
571 message_id,
572 None,
573 Some(sender_id.into()),
574 )))?;
575 Ok(())
576 }
577
578 pub fn respond<T: RequestMessage>(
579 &self,
580 receipt: Receipt<T>,
581 response: T::Response,
582 ) -> Result<()> {
583 let connection = self.connection_state(receipt.sender_id.into())?;
584 let message_id = connection
585 .next_message_id
586 .fetch_add(1, atomic::Ordering::SeqCst);
587 connection
588 .outgoing_tx
589 .unbounded_send(Message::Envelope(response.into_envelope(
590 message_id,
591 Some(receipt.message_id),
592 None,
593 )))?;
594 Ok(())
595 }
596
597 pub fn end_stream<T: RequestMessage>(&self, receipt: Receipt<T>) -> Result<()> {
598 let connection = self.connection_state(receipt.sender_id.into())?;
599 let message_id = connection
600 .next_message_id
601 .fetch_add(1, atomic::Ordering::SeqCst);
602
603 let message = proto::EndStream {};
604
605 connection
606 .outgoing_tx
607 .unbounded_send(Message::Envelope(message.into_envelope(
608 message_id,
609 Some(receipt.message_id),
610 None,
611 )))?;
612 Ok(())
613 }
614
615 pub fn respond_with_error<T: RequestMessage>(
616 &self,
617 receipt: Receipt<T>,
618 response: proto::Error,
619 ) -> Result<()> {
620 let connection = self.connection_state(receipt.sender_id.into())?;
621 let message_id = connection
622 .next_message_id
623 .fetch_add(1, atomic::Ordering::SeqCst);
624 connection
625 .outgoing_tx
626 .unbounded_send(Message::Envelope(response.into_envelope(
627 message_id,
628 Some(receipt.message_id),
629 None,
630 )))?;
631 Ok(())
632 }
633
634 pub fn respond_with_unhandled_message(
635 &self,
636 sender_id: ConnectionId,
637 request_message_id: u32,
638 message_type_name: &'static str,
639 ) -> Result<()> {
640 let connection = self.connection_state(sender_id)?;
641 let response = ErrorCode::Internal
642 .message(format!("message {} was not handled", message_type_name))
643 .to_proto();
644 let message_id = connection
645 .next_message_id
646 .fetch_add(1, atomic::Ordering::SeqCst);
647 connection
648 .outgoing_tx
649 .unbounded_send(Message::Envelope(response.into_envelope(
650 message_id,
651 Some(request_message_id),
652 None,
653 )))?;
654 Ok(())
655 }
656
657 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
658 let connections = self.connections.read();
659 let connection = connections
660 .get(&connection_id)
661 .with_context(|| format!("no such connection: {connection_id}"))?;
662 Ok(connection.clone())
663 }
664}
665
666impl Serialize for Peer {
667 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
668 where
669 S: serde::Serializer,
670 {
671 let mut state = serializer.serialize_struct("Peer", 2)?;
672 state.serialize_field("connections", &*self.connections.read())?;
673 state.end()
674 }
675}
676
#[cfg(test)]
mod tests {
    use super::*;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    fn init_logger() {
        zlog::init_test();
    }

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        init_logger();

        let executor = cx.executor();

        // create 2 clients connected to 1 server
        let server = Peer::new(0);
        let client1 = Peer::new(0);
        let client2 = Peer::new(0);

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_test_connection(client1_to_server_conn, cx.executor());
        let (_, io_task2, server_incoming1) =
            server.add_test_connection(server_to_client_1_conn, cx.executor());

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_test_connection(client2_to_server_conn, cx.executor());
        let (_, io_task4, server_incoming2) =
            server.add_test_connection(server_to_client_2_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects its own connection. The ids only happen to be
        // equal because both peers start at epoch 0 / id 0.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(executor.clone());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, executor.clone());

        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, executor.clone());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let future = server_incoming.next().await;
                let request = future
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let server = Peer::new(0);
        let client = Peer::new(0);

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.executor());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_test_connection(client_to_server_conn, cx.executor());
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_test_connection(server_to_client_conn, cx.executor());

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 1".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        ErrorCode::Internal
                            .message("message 2".to_string())
                            .to_proto(),
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.executor().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.executor();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(
            server_conn
                .send(WebSocketMessage::Binary(vec![].into()))
                .await
                .is_err()
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.executor();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(executor.clone());

        let client = Peer::new(0);
        let (connection_id, io_handler, mut incoming) =
            client.add_test_connection(client_conn, executor.clone());
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}