1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use postage::{
7 barrier, mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifies a single connection registered with a [`Peer`].
///
/// Values are handed out by a per-`Peer` counter in `Peer::add_connection`, so ids are only
/// unique within the `Peer` that allocated them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Renders the id as its bare number, honoring any width/fill flags of the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
32
/// Identifies the original sender of a forwarded message.
///
/// The forwarding methods on `Peer` put the raw number of the forwarding side's
/// `ConnectionId` on the wire; receivers see it as `TypedEnvelope::original_sender_id`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Renders the id as its bare number, honoring any width/fill flags of the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
41
42pub struct Receipt<T> {
43 pub sender_id: ConnectionId,
44 pub message_id: u32,
45 payload_type: PhantomData<T>,
46}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
60pub struct TypedEnvelope<T> {
61 pub sender_id: ConnectionId,
62 pub original_sender_id: Option<PeerId>,
63 pub message_id: u32,
64 pub payload: T,
65}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
84pub struct Peer {
85 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
86 next_connection_id: AtomicU32,
87}
88
89#[derive(Clone)]
90pub struct ConnectionState {
91 outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
92 next_message_id: Arc<AtomicU32>,
93 response_channels:
94 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
95}
96
97const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
98const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
99const RECEIVE_TIMEOUT: Duration = Duration::from_secs(30);
100
101impl Peer {
102 pub fn new() -> Arc<Self> {
103 Arc::new(Self {
104 connections: Default::default(),
105 next_connection_id: Default::default(),
106 })
107 }
108
109 pub async fn add_connection<F, Fut, Out>(
110 self: &Arc<Self>,
111 connection: Connection,
112 create_timer: F,
113 ) -> (
114 ConnectionId,
115 impl Future<Output = anyhow::Result<()>> + Send,
116 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
117 )
118 where
119 F: Send + Fn(Duration) -> Fut,
120 Fut: Send + Future<Output = Out>,
121 Out: Send,
122 {
123 // For outgoing messages, use an unbounded channel so that application code
124 // can always send messages without yielding. For incoming messages, use a
125 // bounded channel so that other peers will receive backpressure if they send
126 // messages faster than this peer can process them.
127 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
128 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
129
130 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
131 let connection_state = ConnectionState {
132 outgoing_tx: outgoing_tx.clone(),
133 next_message_id: Default::default(),
134 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
135 };
136 let mut writer = MessageStream::new(connection.tx);
137 let mut reader = MessageStream::new(connection.rx);
138
139 let this = self.clone();
140 let response_channels = connection_state.response_channels.clone();
141 let handle_io = async move {
142 let _end_connection = util::defer(|| {
143 response_channels.lock().take();
144 this.connections.write().remove(&connection_id);
145 });
146
147 // Send messages on this frequency so the connection isn't closed.
148 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
149 futures::pin_mut!(keepalive_timer);
150
151 // Disconnect if we don't receive messages at least this frequently.
152 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
153 futures::pin_mut!(receive_timeout);
154
155 loop {
156 let read_message = reader.read().fuse();
157 futures::pin_mut!(read_message);
158
159 loop {
160 futures::select_biased! {
161 outgoing = outgoing_rx.next().fuse() => match outgoing {
162 Some(outgoing) => {
163 if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
164 result.context("failed to write RPC message")?;
165 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
166 } else {
167 Err(anyhow!("timed out writing message"))?;
168 }
169 }
170 None => return Ok(()),
171 },
172 incoming = read_message => {
173 let incoming = incoming.context("received invalid RPC message")?;
174 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
175 if let proto::Message::Envelope(incoming) = incoming {
176 if incoming_tx.send(incoming).await.is_err() {
177 return Ok(());
178 }
179 }
180 break;
181 },
182 _ = keepalive_timer => {
183 if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
184 result.context("failed to send keepalive")?;
185 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
186 } else {
187 Err(anyhow!("timed out sending keepalive"))?;
188 }
189 }
190 _ = receive_timeout => {
191 Err(anyhow!("delay between messages too long"))?
192 }
193 }
194 }
195 }
196 };
197
198 let response_channels = connection_state.response_channels.clone();
199 self.connections
200 .write()
201 .insert(connection_id, connection_state);
202
203 let incoming_rx = incoming_rx.filter_map(move |incoming| {
204 let response_channels = response_channels.clone();
205 async move {
206 if let Some(responding_to) = incoming.responding_to {
207 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
208 if let Some(tx) = channel {
209 let mut requester_resumed = barrier::channel();
210 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
211 log::debug!(
212 "received RPC but request future was dropped {:?}",
213 error.0
214 );
215 }
216 requester_resumed.1.recv().await;
217 } else {
218 log::warn!("received RPC response to unknown request {}", responding_to);
219 }
220
221 None
222 } else {
223 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
224 log::error!("unable to construct a typed envelope");
225 None
226 })
227 }
228 }
229 });
230 (connection_id, handle_io, incoming_rx.boxed())
231 }
232
233 #[cfg(any(test, feature = "test-support"))]
234 pub async fn add_test_connection(
235 self: &Arc<Self>,
236 connection: Connection,
237 executor: Arc<gpui::executor::Background>,
238 ) -> (
239 ConnectionId,
240 impl Future<Output = anyhow::Result<()>> + Send,
241 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
242 ) {
243 let executor = executor.clone();
244 self.add_connection(connection, move |duration| executor.timer(duration))
245 .await
246 }
247
248 pub fn disconnect(&self, connection_id: ConnectionId) {
249 self.connections.write().remove(&connection_id);
250 }
251
252 pub fn reset(&self) {
253 self.connections.write().clear();
254 }
255
256 pub fn request<T: RequestMessage>(
257 &self,
258 receiver_id: ConnectionId,
259 request: T,
260 ) -> impl Future<Output = Result<T::Response>> {
261 self.request_internal(None, receiver_id, request)
262 }
263
264 pub fn forward_request<T: RequestMessage>(
265 &self,
266 sender_id: ConnectionId,
267 receiver_id: ConnectionId,
268 request: T,
269 ) -> impl Future<Output = Result<T::Response>> {
270 self.request_internal(Some(sender_id), receiver_id, request)
271 }
272
273 pub fn request_internal<T: RequestMessage>(
274 &self,
275 original_sender_id: Option<ConnectionId>,
276 receiver_id: ConnectionId,
277 request: T,
278 ) -> impl Future<Output = Result<T::Response>> {
279 let (tx, rx) = oneshot::channel();
280 let send = self.connection_state(receiver_id).and_then(|connection| {
281 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
282 connection
283 .response_channels
284 .lock()
285 .as_mut()
286 .ok_or_else(|| anyhow!("connection was closed"))?
287 .insert(message_id, tx);
288 connection
289 .outgoing_tx
290 .unbounded_send(proto::Message::Envelope(request.into_envelope(
291 message_id,
292 None,
293 original_sender_id.map(|id| id.0),
294 )))
295 .map_err(|_| anyhow!("connection was closed"))?;
296 Ok(())
297 });
298 async move {
299 send?;
300 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
301 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
302 Err(anyhow!("RPC request failed - {}", error.message))
303 } else {
304 T::Response::from_envelope(response)
305 .ok_or_else(|| anyhow!("received response of the wrong type"))
306 }
307 }
308 }
309
310 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
311 let connection = self.connection_state(receiver_id)?;
312 let message_id = connection
313 .next_message_id
314 .fetch_add(1, atomic::Ordering::SeqCst);
315 connection
316 .outgoing_tx
317 .unbounded_send(proto::Message::Envelope(
318 message.into_envelope(message_id, None, None),
319 ))?;
320 Ok(())
321 }
322
323 pub fn forward_send<T: EnvelopedMessage>(
324 &self,
325 sender_id: ConnectionId,
326 receiver_id: ConnectionId,
327 message: T,
328 ) -> Result<()> {
329 let connection = self.connection_state(receiver_id)?;
330 let message_id = connection
331 .next_message_id
332 .fetch_add(1, atomic::Ordering::SeqCst);
333 connection
334 .outgoing_tx
335 .unbounded_send(proto::Message::Envelope(message.into_envelope(
336 message_id,
337 None,
338 Some(sender_id.0),
339 )))?;
340 Ok(())
341 }
342
343 pub fn respond<T: RequestMessage>(
344 &self,
345 receipt: Receipt<T>,
346 response: T::Response,
347 ) -> Result<()> {
348 let connection = self.connection_state(receipt.sender_id)?;
349 let message_id = connection
350 .next_message_id
351 .fetch_add(1, atomic::Ordering::SeqCst);
352 connection
353 .outgoing_tx
354 .unbounded_send(proto::Message::Envelope(response.into_envelope(
355 message_id,
356 Some(receipt.message_id),
357 None,
358 )))?;
359 Ok(())
360 }
361
362 pub fn respond_with_error<T: RequestMessage>(
363 &self,
364 receipt: Receipt<T>,
365 response: proto::Error,
366 ) -> Result<()> {
367 let connection = self.connection_state(receipt.sender_id)?;
368 let message_id = connection
369 .next_message_id
370 .fetch_add(1, atomic::Ordering::SeqCst);
371 connection
372 .outgoing_tx
373 .unbounded_send(proto::Message::Envelope(response.into_envelope(
374 message_id,
375 Some(receipt.message_id),
376 None,
377 )))?;
378 Ok(())
379 }
380
381 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
382 let connections = self.connections.read();
383 let connection = connections
384 .get(&connection_id)
385 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
386 Ok(connection.clone())
387 }
388}
389
#[cfg(test)]
mod tests {
    // Tests for request/response routing, message ordering, and connection teardown.
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Each client must be disconnected with its own connection id.
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}