1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
8use parking_lot::{Mutex, RwLock};
9use postage::{
10 barrier, mpsc,
11 prelude::{Sink as _, Stream as _},
12};
13use smol_timeout::TimeoutExt as _;
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25
/// Identifies one connection registered with a `Peer`.
///
/// Ids are handed out per `Peer` by a counter that starts at zero, so the same
/// numeric id can appear on two different peers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Prints the raw numeric id, honouring any width/fill flags on `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
34
/// Identifies the peer on whose behalf a message was forwarded; it is carried
/// as `TypedEnvelope::original_sender_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Prints the raw numeric id, honouring any width/fill flags on `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
43
44pub struct Receipt<T> {
45 pub sender_id: ConnectionId,
46 pub message_id: u32,
47 payload_type: PhantomData<T>,
48}
49
50impl<T> Clone for Receipt<T> {
51 fn clone(&self) -> Self {
52 Self {
53 sender_id: self.sender_id,
54 message_id: self.message_id,
55 payload_type: PhantomData,
56 }
57 }
58}
59
60impl<T> Copy for Receipt<T> {}
61
62pub struct TypedEnvelope<T> {
63 pub sender_id: ConnectionId,
64 pub original_sender_id: Option<PeerId>,
65 pub message_id: u32,
66 pub payload: T,
67}
68
69impl<T> TypedEnvelope<T> {
70 pub fn original_sender_id(&self) -> Result<PeerId> {
71 self.original_sender_id
72 .ok_or_else(|| anyhow!("missing original_sender_id"))
73 }
74}
75
76impl<T: RequestMessage> TypedEnvelope<T> {
77 pub fn receipt(&self) -> Receipt<T> {
78 Receipt {
79 sender_id: self.sender_id,
80 message_id: self.message_id,
81 payload_type: PhantomData,
82 }
83 }
84}
85
/// Multiplexes RPC messages over any number of connections.
///
/// Each connection is tracked by a `ConnectionState` keyed by its `ConnectionId`.
/// Entries are inserted by `add_connection`. They are removed by `disconnect`,
/// `reset`, or when the connection's I/O future finishes.
pub struct Peer {
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Counter used to allocate ids for newly added connections; starts at 0.
    next_connection_id: AtomicU32,
}

/// Per-connection state shared between the `Peer` and the connection's I/O future.
///
/// Cloning only clones handles: every field is a channel sender or an `Arc`.
#[derive(Clone)]
pub struct ConnectionState {
    // Messages queued for writing. This queue is drained by the I/O future returned
    // from `add_connection`.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
    // Counter used to assign ids to messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Senders for requests awaiting a response, keyed by the request's message id.
    // It becomes `None` when the connection ends: new requests then fail, and
    // pending ones see their sender dropped. The `barrier::Sender` delivered with
    // each response lets the incoming stream wait until the requester is done with it.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}

/// How long a connection may go without writing anything before a `Ping`
/// keepalive is sent.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time allowed to write a single message before the connection fails.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Maximum time allowed between received messages before the connection fails.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
102
103impl Peer {
104 pub fn new() -> Arc<Self> {
105 Arc::new(Self {
106 connections: Default::default(),
107 next_connection_id: Default::default(),
108 })
109 }
110
111 pub async fn add_connection<F, Fut, Out>(
112 self: &Arc<Self>,
113 connection: Connection,
114 create_timer: F,
115 ) -> (
116 ConnectionId,
117 impl Future<Output = anyhow::Result<()>> + Send,
118 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
119 )
120 where
121 F: Send + Fn(Duration) -> Fut,
122 Fut: Send + Future<Output = Out>,
123 Out: Send,
124 {
125 // For outgoing messages, use an unbounded channel so that application code
126 // can always send messages without yielding. For incoming messages, use a
127 // bounded channel so that other peers will receive backpressure if they send
128 // messages faster than this peer can process them.
129 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
130 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
131
132 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
133 let connection_state = ConnectionState {
134 outgoing_tx: outgoing_tx.clone(),
135 next_message_id: Default::default(),
136 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
137 };
138 let mut writer = MessageStream::new(connection.tx);
139 let mut reader = MessageStream::new(connection.rx);
140
141 let this = self.clone();
142 let response_channels = connection_state.response_channels.clone();
143 let handle_io = async move {
144 let _end_connection = util::defer(|| {
145 response_channels.lock().take();
146 this.connections.write().remove(&connection_id);
147 });
148
149 // Send messages on this frequency so the connection isn't closed.
150 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
151 futures::pin_mut!(keepalive_timer);
152
153 // Disconnect if we don't receive messages at least this frequently.
154 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
155 futures::pin_mut!(receive_timeout);
156
157 loop {
158 let read_message = reader.read().fuse();
159 futures::pin_mut!(read_message);
160
161 loop {
162 futures::select_biased! {
163 outgoing = outgoing_rx.next().fuse() => match outgoing {
164 Some(outgoing) => {
165 if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
166 result.context("failed to write RPC message")?;
167 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
168 } else {
169 Err(anyhow!("timed out writing message"))?;
170 }
171 }
172 None => return Ok(()),
173 },
174 incoming = read_message => {
175 let incoming = incoming.context("received invalid RPC message")?;
176 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
177 if let proto::Message::Envelope(incoming) = incoming {
178 if incoming_tx.send(incoming).await.is_err() {
179 return Ok(());
180 }
181 }
182 break;
183 },
184 _ = keepalive_timer => {
185 if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
186 result.context("failed to send keepalive")?;
187 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
188 } else {
189 Err(anyhow!("timed out sending keepalive"))?;
190 }
191 }
192 _ = receive_timeout => {
193 Err(anyhow!("delay between messages too long"))?
194 }
195 }
196 }
197 }
198 };
199
200 let response_channels = connection_state.response_channels.clone();
201 self.connections
202 .write()
203 .insert(connection_id, connection_state);
204
205 let incoming_rx = incoming_rx.filter_map(move |incoming| {
206 let response_channels = response_channels.clone();
207 async move {
208 if let Some(responding_to) = incoming.responding_to {
209 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
210 if let Some(tx) = channel {
211 let mut requester_resumed = barrier::channel();
212 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
213 log::debug!(
214 "received RPC but request future was dropped {:?}",
215 error.0
216 );
217 }
218 requester_resumed.1.recv().await;
219 } else {
220 log::warn!("received RPC response to unknown request {}", responding_to);
221 }
222
223 None
224 } else {
225 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
226 log::error!("unable to construct a typed envelope");
227 None
228 })
229 }
230 }
231 });
232 (connection_id, handle_io, incoming_rx.boxed())
233 }
234
235 #[cfg(any(test, feature = "test-support"))]
236 pub async fn add_test_connection(
237 self: &Arc<Self>,
238 connection: Connection,
239 executor: Arc<gpui::executor::Background>,
240 ) -> (
241 ConnectionId,
242 impl Future<Output = anyhow::Result<()>> + Send,
243 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
244 ) {
245 let executor = executor.clone();
246 self.add_connection(connection, move |duration| executor.timer(duration))
247 .await
248 }
249
250 pub fn disconnect(&self, connection_id: ConnectionId) {
251 self.connections.write().remove(&connection_id);
252 }
253
254 pub fn reset(&self) {
255 self.connections.write().clear();
256 }
257
258 pub fn request<T: RequestMessage>(
259 &self,
260 receiver_id: ConnectionId,
261 request: T,
262 ) -> impl Future<Output = Result<T::Response>> {
263 self.request_internal(None, receiver_id, request)
264 }
265
266 pub fn forward_request<T: RequestMessage>(
267 &self,
268 sender_id: ConnectionId,
269 receiver_id: ConnectionId,
270 request: T,
271 ) -> impl Future<Output = Result<T::Response>> {
272 self.request_internal(Some(sender_id), receiver_id, request)
273 }
274
275 pub fn request_internal<T: RequestMessage>(
276 &self,
277 original_sender_id: Option<ConnectionId>,
278 receiver_id: ConnectionId,
279 request: T,
280 ) -> impl Future<Output = Result<T::Response>> {
281 let (tx, rx) = oneshot::channel();
282 let send = self.connection_state(receiver_id).and_then(|connection| {
283 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
284 connection
285 .response_channels
286 .lock()
287 .as_mut()
288 .ok_or_else(|| anyhow!("connection was closed"))?
289 .insert(message_id, tx);
290 connection
291 .outgoing_tx
292 .unbounded_send(proto::Message::Envelope(request.into_envelope(
293 message_id,
294 None,
295 original_sender_id.map(|id| id.0),
296 )))
297 .map_err(|_| anyhow!("connection was closed"))?;
298 Ok(())
299 });
300 async move {
301 send?;
302 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
303 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
304 Err(anyhow!("RPC request failed - {}", error.message))
305 } else {
306 T::Response::from_envelope(response)
307 .ok_or_else(|| anyhow!("received response of the wrong type"))
308 }
309 }
310 }
311
312 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
313 let connection = self.connection_state(receiver_id)?;
314 let message_id = connection
315 .next_message_id
316 .fetch_add(1, atomic::Ordering::SeqCst);
317 connection
318 .outgoing_tx
319 .unbounded_send(proto::Message::Envelope(
320 message.into_envelope(message_id, None, None),
321 ))?;
322 Ok(())
323 }
324
325 pub fn forward_send<T: EnvelopedMessage>(
326 &self,
327 sender_id: ConnectionId,
328 receiver_id: ConnectionId,
329 message: T,
330 ) -> Result<()> {
331 let connection = self.connection_state(receiver_id)?;
332 let message_id = connection
333 .next_message_id
334 .fetch_add(1, atomic::Ordering::SeqCst);
335 connection
336 .outgoing_tx
337 .unbounded_send(proto::Message::Envelope(message.into_envelope(
338 message_id,
339 None,
340 Some(sender_id.0),
341 )))?;
342 Ok(())
343 }
344
345 pub fn respond<T: RequestMessage>(
346 &self,
347 receipt: Receipt<T>,
348 response: T::Response,
349 ) -> Result<()> {
350 let connection = self.connection_state(receipt.sender_id)?;
351 let message_id = connection
352 .next_message_id
353 .fetch_add(1, atomic::Ordering::SeqCst);
354 connection
355 .outgoing_tx
356 .unbounded_send(proto::Message::Envelope(response.into_envelope(
357 message_id,
358 Some(receipt.message_id),
359 None,
360 )))?;
361 Ok(())
362 }
363
364 pub fn respond_with_error<T: RequestMessage>(
365 &self,
366 receipt: Receipt<T>,
367 response: proto::Error,
368 ) -> Result<()> {
369 let connection = self.connection_state(receipt.sender_id)?;
370 let message_id = connection
371 .next_message_id
372 .fetch_add(1, atomic::Ordering::SeqCst);
373 connection
374 .outgoing_tx
375 .unbounded_send(proto::Message::Envelope(response.into_envelope(
376 message_id,
377 Some(receipt.message_id),
378 None,
379 )))?;
380 Ok(())
381 }
382
383 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
384 let connections = self.connections.read();
385 let connection = connections
386 .get(&connection_id)
387 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
388 Ok(connection.clone())
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server can each issue requests and receive
    /// the matching responses.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each peer disconnects its own connection id. Ids are allocated per peer,
        // so passing another peer's id only works by coincidence.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages the server sends before its response are yielded by the client's
    /// incoming stream before the request future observes the response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request does not stall delivery of later incoming
    /// messages or of another request's response.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream, and closes
    /// the underlying connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// When the remote end of the connection goes away, a pending request fails
    /// with `connection was closed`.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}