1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use postage::{
7 barrier, mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifies one connection registered with a local [`Peer`].
///
/// Ids are assigned from a per-peer monotonic counter, so they are only
/// meaningful relative to the `Peer` that created them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward to the inner integer so width/fill flags keep working.
        fmt::Display::fmt(&self.0, f)
    }
}
32
/// Identifies a remote peer, as assigned by the server that forwards
/// messages on its behalf.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward to the inner integer so width/fill flags keep working.
        fmt::Display::fmt(&self.0, f)
    }
}
41
42pub struct Receipt<T> {
43 pub sender_id: ConnectionId,
44 pub message_id: u32,
45 payload_type: PhantomData<T>,
46}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
/// An incoming envelope whose payload has been decoded into a concrete
/// message type `T`.
pub struct TypedEnvelope<T> {
    /// The local connection this message arrived on.
    pub sender_id: ConnectionId,
    /// Id of the peer that originally produced the message when it was
    /// forwarded through a server (see `Peer::forward_send`/`forward_request`);
    /// `None` for messages sent directly.
    pub original_sender_id: Option<PeerId>,
    /// Per-connection message id assigned by the sender.
    pub message_id: u32,
    /// The decoded message payload.
    pub payload: T,
}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
/// Manages a set of RPC connections: routes outgoing messages and requests,
/// and matches incoming responses to their pending requests.
pub struct Peer {
    /// State for each live connection, keyed by its id.
    /// NOTE(review): this field is `pub`, so external code can mutate the map
    /// directly, bypassing `disconnect`/`reset` — confirm this is intended.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Monotonic counter from which new `ConnectionId`s are drawn
    /// (see `add_connection`).
    next_connection_id: AtomicU32,
}
88
/// Per-connection shared state. Cloning is cheap: all fields are channel
/// handles or `Arc`s referring to the same underlying state.
#[derive(Clone)]
pub struct ConnectionState {
    // Application-side handle for queueing messages to be written by the
    // connection's I/O task. Unbounded, so senders never have to yield.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
    // Source of per-connection message ids for outgoing messages.
    next_message_id: Arc<AtomicU32>,
    // Pending-request response channels, keyed by the request's message id.
    // Set to `None` when the connection closes so new requests fail with
    // "connection was closed" (see `add_connection` / `request_internal`).
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
96
/// How often a `Ping` is written when nothing else has been sent recently.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Maximum time a single write may take before the connection is torn down.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
99
100impl Peer {
101 pub fn new() -> Arc<Self> {
102 Arc::new(Self {
103 connections: Default::default(),
104 next_connection_id: Default::default(),
105 })
106 }
107
    /// Registers `connection` with this peer and returns:
    /// - the new connection's id,
    /// - a future that drives the connection's I/O (the caller must spawn and
    ///   poll it to completion; when it resolves, the connection is dead),
    /// - a stream of incoming, typed (non-response) messages.
    ///
    /// `create_timer` supplies the timers used for keepalives and receive
    /// timeouts; it is injected so tests can control time.
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            // However this loop exits (error or clean shutdown), tear down the
            // connection: taking `response_channels` drops every pending
            // oneshot sender, failing in-flight requests with
            // "connection was closed".
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            loop {
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                // Disconnect if we don't receive messages at least this frequently.
                let receive_timeout = create_timer(3 * KEEPALIVE_INTERVAL).fuse();
                futures::pin_mut!(receive_timeout);

                loop {
                    // `select_biased!` prefers draining outgoing messages
                    // before reading, in declaration order.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
                                    result.context("failed to write RPC message")?;
                                    // Any successful write counts as liveness;
                                    // push the next keepalive out.
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                } else {
                                    Err(anyhow!("timed out writing message"))?;
                                }
                            }
                            // All senders dropped: the connection was removed
                            // (disconnect/reset), so shut down cleanly.
                            None => return Ok(()),
                        },
                        incoming = read_message => {
                            let incoming = incoming.context("received invalid RPC message")?;
                            if let proto::Message::Envelope(incoming) = incoming {
                                if incoming_tx.send(incoming).await.is_err() {
                                    return Ok(());
                                }
                            }
                            // Break to the outer loop to re-arm the receive
                            // timeout after every message (including pings).
                            break;
                        },
                        _ = keepalive_timer => {
                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
                                result.context("failed to send keepalive")?;
                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                            } else {
                                Err(anyhow!("timed out sending keepalive"))?;
                            }
                        }
                        _ = receive_timeout => {
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Demultiplex incoming envelopes: responses are routed to the pending
        // request's oneshot channel and filtered out of the stream; everything
        // else is decoded into a typed envelope for the application.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let mut requester_resumed = barrier::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                        // Block this stream until the requester has consumed
                        // the response (it drops the barrier sender), so a
                        // response is always observed before any message that
                        // arrived after it.
                        requester_resumed.1.recv().await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else {
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        log::error!("unable to construct a typed envelope");
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
230
231 #[cfg(any(test, feature = "test-support"))]
232 pub async fn add_test_connection(
233 self: &Arc<Self>,
234 connection: Connection,
235 executor: Arc<gpui::executor::Background>,
236 ) -> (
237 ConnectionId,
238 impl Future<Output = anyhow::Result<()>> + Send,
239 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
240 ) {
241 let executor = executor.clone();
242 self.add_connection(connection, move |duration| executor.timer(duration))
243 .await
244 }
245
246 pub fn disconnect(&self, connection_id: ConnectionId) {
247 self.connections.write().remove(&connection_id);
248 }
249
250 pub fn reset(&self) {
251 self.connections.write().clear();
252 }
253
254 pub fn request<T: RequestMessage>(
255 &self,
256 receiver_id: ConnectionId,
257 request: T,
258 ) -> impl Future<Output = Result<T::Response>> {
259 self.request_internal(None, receiver_id, request)
260 }
261
262 pub fn forward_request<T: RequestMessage>(
263 &self,
264 sender_id: ConnectionId,
265 receiver_id: ConnectionId,
266 request: T,
267 ) -> impl Future<Output = Result<T::Response>> {
268 self.request_internal(Some(sender_id), receiver_id, request)
269 }
270
    /// Shared implementation for `request` and `forward_request`.
    ///
    /// Registers a oneshot response channel under a fresh message id *before*
    /// sending the request, so the response cannot race past the
    /// registration, then sends the enveloped request on the connection's
    /// outgoing channel. The returned future resolves when the response
    /// arrives, or fails if the connection closes first.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        // All fallible setup happens eagerly; the returned future only awaits
        // the response (and reports any setup error on first poll).
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                // `None` means the connection's I/O loop already shut down.
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                )))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // `_barrier` is dropped when this future finishes, which unblocks
            // the incoming-message stream (see `add_connection`) so messages
            // that arrived after the response can be delivered.
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }
307
308 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
309 let connection = self.connection_state(receiver_id)?;
310 let message_id = connection
311 .next_message_id
312 .fetch_add(1, atomic::Ordering::SeqCst);
313 connection
314 .outgoing_tx
315 .unbounded_send(proto::Message::Envelope(
316 message.into_envelope(message_id, None, None),
317 ))?;
318 Ok(())
319 }
320
321 pub fn forward_send<T: EnvelopedMessage>(
322 &self,
323 sender_id: ConnectionId,
324 receiver_id: ConnectionId,
325 message: T,
326 ) -> Result<()> {
327 let connection = self.connection_state(receiver_id)?;
328 let message_id = connection
329 .next_message_id
330 .fetch_add(1, atomic::Ordering::SeqCst);
331 connection
332 .outgoing_tx
333 .unbounded_send(proto::Message::Envelope(message.into_envelope(
334 message_id,
335 None,
336 Some(sender_id.0),
337 )))?;
338 Ok(())
339 }
340
341 pub fn respond<T: RequestMessage>(
342 &self,
343 receipt: Receipt<T>,
344 response: T::Response,
345 ) -> Result<()> {
346 let connection = self.connection_state(receipt.sender_id)?;
347 let message_id = connection
348 .next_message_id
349 .fetch_add(1, atomic::Ordering::SeqCst);
350 connection
351 .outgoing_tx
352 .unbounded_send(proto::Message::Envelope(response.into_envelope(
353 message_id,
354 Some(receipt.message_id),
355 None,
356 )))?;
357 Ok(())
358 }
359
360 pub fn respond_with_error<T: RequestMessage>(
361 &self,
362 receipt: Receipt<T>,
363 response: proto::Error,
364 ) -> Result<()> {
365 let connection = self.connection_state(receipt.sender_id)?;
366 let message_id = connection
367 .next_message_id
368 .fetch_add(1, atomic::Ordering::SeqCst);
369 connection
370 .outgoing_tx
371 .unbounded_send(proto::Message::Envelope(response.into_envelope(
372 message_id,
373 Some(receipt.message_id),
374 None,
375 )))?;
376 Ok(())
377 }
378
379 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
380 let connections = self.connections.read();
381 let connection = connections
382 .get(&connection_id)
383 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
384 Ok(connection.clone())
385 }
386}
387
388#[cfg(test)]
389mod tests {
390 use super::*;
391 use crate::TypedEnvelope;
392 use async_tungstenite::tungstenite::Message as WebSocketMessage;
393 use gpui::TestAppContext;
394
395 #[gpui::test(iterations = 50)]
396 async fn test_request_response(cx: &mut TestAppContext) {
397 let executor = cx.foreground();
398
399 // create 2 clients connected to 1 server
400 let server = Peer::new();
401 let client1 = Peer::new();
402 let client2 = Peer::new();
403
404 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
405 Connection::in_memory(cx.background());
406 let (client1_conn_id, io_task1, client1_incoming) = client1
407 .add_test_connection(client1_to_server_conn, cx.background())
408 .await;
409 let (_, io_task2, server_incoming1) = server
410 .add_test_connection(server_to_client_1_conn, cx.background())
411 .await;
412
413 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
414 Connection::in_memory(cx.background());
415 let (client2_conn_id, io_task3, client2_incoming) = client2
416 .add_test_connection(client2_to_server_conn, cx.background())
417 .await;
418 let (_, io_task4, server_incoming2) = server
419 .add_test_connection(server_to_client_2_conn, cx.background())
420 .await;
421
422 executor.spawn(io_task1).detach();
423 executor.spawn(io_task2).detach();
424 executor.spawn(io_task3).detach();
425 executor.spawn(io_task4).detach();
426 executor
427 .spawn(handle_messages(server_incoming1, server.clone()))
428 .detach();
429 executor
430 .spawn(handle_messages(client1_incoming, client1.clone()))
431 .detach();
432 executor
433 .spawn(handle_messages(server_incoming2, server.clone()))
434 .detach();
435 executor
436 .spawn(handle_messages(client2_incoming, client2.clone()))
437 .detach();
438
439 assert_eq!(
440 client1
441 .request(client1_conn_id, proto::Ping {},)
442 .await
443 .unwrap(),
444 proto::Ack {}
445 );
446
447 assert_eq!(
448 client2
449 .request(client2_conn_id, proto::Ping {},)
450 .await
451 .unwrap(),
452 proto::Ack {}
453 );
454
455 assert_eq!(
456 client1
457 .request(client1_conn_id, proto::Test { id: 1 },)
458 .await
459 .unwrap(),
460 proto::Test { id: 1 }
461 );
462
463 assert_eq!(
464 client2
465 .request(client2_conn_id, proto::Test { id: 2 })
466 .await
467 .unwrap(),
468 proto::Test { id: 2 }
469 );
470
471 client1.disconnect(client1_conn_id);
472 client2.disconnect(client1_conn_id);
473
474 async fn handle_messages(
475 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
476 peer: Arc<Peer>,
477 ) -> Result<()> {
478 while let Some(envelope) = messages.next().await {
479 let envelope = envelope.into_any();
480 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
481 let receipt = envelope.receipt();
482 peer.respond(receipt, proto::Ack {})?
483 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
484 {
485 peer.respond(envelope.receipt(), envelope.payload.clone())?
486 } else {
487 panic!("unknown message type");
488 }
489 }
490
491 Ok(())
492 }
493 }
494
    /// Verifies delivery ordering: a response must be observed by the
    /// requester only after every message the server sent *before* that
    /// response has been delivered (the barrier mechanism in
    /// `add_connection` enforces this).
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Send two ordinary messages before answering the request.
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both messages must have been recorded before the response resolved.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
594
    /// Dropping a request future before its response arrives must not break
    /// delivery of subsequent messages or of a second, still-pending request.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                // Receive both pings, then interleave two plain messages
                // before responding to each request.
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // The dropped request must not have stalled or reordered anything.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
708
    /// Disconnecting locally must end the I/O task, terminate the incoming
    /// message stream, and close the underlying transport for the remote side.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        // Barriers fire once their task observes the corresponding shutdown.
        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        // The server half should now fail to write, proving the transport closed.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
745
    /// A transport failure while a request is in flight must resolve the
    /// pending request with a "connection was closed" error.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request has actually reached the wire...
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        // ...then kill the transport while the response is still pending.
        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
769}