1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use smol_timeout::TimeoutExt;
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25
/// Identifies one connection registered with a [`Peer`].
///
/// Values are allocated by [`Peer::add_connection`] from a per-peer counter.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Writes the raw numeric id, honoring any width/fill/padding flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = *self;
        fmt::Display::fmt(&raw, f)
    }
}
34
/// Numeric identifier of a peer as carried in a forwarded envelope's
/// original-sender field.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Writes the raw numeric id, honoring any width/fill/padding flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = *self;
        fmt::Display::fmt(&raw, f)
    }
}
43
44pub struct Receipt<T> {
45 pub sender_id: ConnectionId,
46 pub message_id: u32,
47 payload_type: PhantomData<T>,
48}
49
50impl<T> Clone for Receipt<T> {
51 fn clone(&self) -> Self {
52 Self {
53 sender_id: self.sender_id,
54 message_id: self.message_id,
55 payload_type: PhantomData,
56 }
57 }
58}
59
60impl<T> Copy for Receipt<T> {}
61
62pub struct TypedEnvelope<T> {
63 pub sender_id: ConnectionId,
64 pub original_sender_id: Option<PeerId>,
65 pub message_id: u32,
66 pub payload: T,
67}
68
69impl<T> TypedEnvelope<T> {
70 pub fn original_sender_id(&self) -> Result<PeerId> {
71 self.original_sender_id
72 .ok_or_else(|| anyhow!("missing original_sender_id"))
73 }
74}
75
76impl<T: RequestMessage> TypedEnvelope<T> {
77 pub fn receipt(&self) -> Receipt<T> {
78 Receipt {
79 sender_id: self.sender_id,
80 message_id: self.message_id,
81 payload_type: PhantomData,
82 }
83 }
84}
85
86pub struct Peer {
87 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
88 next_connection_id: AtomicU32,
89}
90
91#[derive(Clone)]
92pub struct ConnectionState {
93 outgoing_tx: mpsc::UnboundedSender<proto::Message>,
94 next_message_id: Arc<AtomicU32>,
95 response_channels:
96 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
97}
98
/// Maximum time a single outgoing write (message or keepalive) may take
/// before the connection is treated as failed.
const WRITE_TIMEOUT: Duration = Duration::from_millis(2_000);

/// Time without any outgoing write after which a `Ping` is written, so the
/// remote side keeps receiving traffic.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);

/// Maximum gap between received messages, and maximum time allowed for
/// handing one incoming message to the application, before the connection
/// is treated as failed.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_millis(5_000);
102
103impl Peer {
104 pub fn new() -> Arc<Self> {
105 Arc::new(Self {
106 connections: Default::default(),
107 next_connection_id: Default::default(),
108 })
109 }
110
111 pub async fn add_connection<F, Fut, Out>(
112 self: &Arc<Self>,
113 connection: Connection,
114 create_timer: F,
115 ) -> (
116 ConnectionId,
117 impl Future<Output = anyhow::Result<()>> + Send,
118 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
119 )
120 where
121 F: Send + Fn(Duration) -> Fut,
122 Fut: Send + Future<Output = Out>,
123 Out: Send,
124 {
125 // For outgoing messages, use an unbounded channel so that application code
126 // can always send messages without yielding. For incoming messages, use a
127 // bounded channel so that other peers will receive backpressure if they send
128 // messages faster than this peer can process them.
129 #[cfg(any(test, feature = "test-support"))]
130 const INCOMING_BUFFER_SIZE: usize = 1;
131 #[cfg(not(any(test, feature = "test-support")))]
132 const INCOMING_BUFFER_SIZE: usize = 64;
133 let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
134 let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
135
136 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
137 let connection_state = ConnectionState {
138 outgoing_tx: outgoing_tx.clone(),
139 next_message_id: Default::default(),
140 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
141 };
142 let mut writer = MessageStream::new(connection.tx);
143 let mut reader = MessageStream::new(connection.rx);
144
145 let this = self.clone();
146 let response_channels = connection_state.response_channels.clone();
147 let handle_io = async move {
148 let _end_connection = util::defer(|| {
149 response_channels.lock().take();
150 this.connections.write().remove(&connection_id);
151 });
152
153 // Send messages on this frequency so the connection isn't closed.
154 let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
155 futures::pin_mut!(keepalive_timer);
156
157 // Disconnect if we don't receive messages at least this frequently.
158 let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
159 futures::pin_mut!(receive_timeout);
160
161 loop {
162 let read_message = reader.read().fuse();
163 futures::pin_mut!(read_message);
164
165 loop {
166 futures::select_biased! {
167 outgoing = outgoing_rx.next().fuse() => match outgoing {
168 Some(outgoing) => {
169 if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
170 result.context("failed to write RPC message")?;
171 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
172 } else {
173 Err(anyhow!("timed out writing message"))?;
174 }
175 }
176 None => {
177 log::info!("outgoing channel closed");
178 return Ok(())
179 },
180 },
181 incoming = read_message => {
182 let incoming = incoming.context("received invalid RPC message")?;
183 receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
184 if let proto::Message::Envelope(incoming) = incoming {
185 match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
186 Some(Ok(_)) => {},
187 Some(Err(_)) => {
188 log::info!("incoming channel closed");
189 return Ok(())
190 },
191 None => Err(anyhow!("timed out processing incoming message"))?,
192 }
193 }
194 break;
195 },
196 _ = keepalive_timer => {
197 if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
198 result.context("failed to send keepalive")?;
199 keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
200 } else {
201 Err(anyhow!("timed out sending keepalive"))?;
202 }
203 }
204 _ = receive_timeout => {
205 Err(anyhow!("delay between messages too long"))?
206 }
207 }
208 }
209 }
210 };
211
212 let response_channels = connection_state.response_channels.clone();
213 self.connections
214 .write()
215 .insert(connection_id, connection_state);
216
217 let incoming_rx = incoming_rx.filter_map(move |incoming| {
218 let response_channels = response_channels.clone();
219 async move {
220 if let Some(responding_to) = incoming.responding_to {
221 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
222 if let Some(tx) = channel {
223 let requester_resumed = oneshot::channel();
224 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
225 log::debug!(
226 "received RPC but request future was dropped {:?}",
227 error.0
228 );
229 }
230 let _ = requester_resumed.1.await;
231 } else {
232 log::warn!("received RPC response to unknown request {}", responding_to);
233 }
234
235 None
236 } else {
237 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
238 log::error!("unable to construct a typed envelope");
239 None
240 })
241 }
242 }
243 });
244 (connection_id, handle_io, incoming_rx.boxed())
245 }
246
247 #[cfg(any(test, feature = "test-support"))]
248 pub async fn add_test_connection(
249 self: &Arc<Self>,
250 connection: Connection,
251 executor: Arc<gpui::executor::Background>,
252 ) -> (
253 ConnectionId,
254 impl Future<Output = anyhow::Result<()>> + Send,
255 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
256 ) {
257 let executor = executor.clone();
258 self.add_connection(connection, move |duration| executor.timer(duration))
259 .await
260 }
261
262 pub fn disconnect(&self, connection_id: ConnectionId) {
263 self.connections.write().remove(&connection_id);
264 }
265
266 pub fn reset(&self) {
267 self.connections.write().clear();
268 }
269
270 pub fn request<T: RequestMessage>(
271 &self,
272 receiver_id: ConnectionId,
273 request: T,
274 ) -> impl Future<Output = Result<T::Response>> {
275 self.request_internal(None, receiver_id, request)
276 }
277
278 pub fn forward_request<T: RequestMessage>(
279 &self,
280 sender_id: ConnectionId,
281 receiver_id: ConnectionId,
282 request: T,
283 ) -> impl Future<Output = Result<T::Response>> {
284 self.request_internal(Some(sender_id), receiver_id, request)
285 }
286
287 pub fn request_internal<T: RequestMessage>(
288 &self,
289 original_sender_id: Option<ConnectionId>,
290 receiver_id: ConnectionId,
291 request: T,
292 ) -> impl Future<Output = Result<T::Response>> {
293 let (tx, rx) = oneshot::channel();
294 let send = self.connection_state(receiver_id).and_then(|connection| {
295 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
296 connection
297 .response_channels
298 .lock()
299 .as_mut()
300 .ok_or_else(|| anyhow!("connection was closed"))?
301 .insert(message_id, tx);
302 connection
303 .outgoing_tx
304 .unbounded_send(proto::Message::Envelope(request.into_envelope(
305 message_id,
306 None,
307 original_sender_id.map(|id| id.0),
308 )))
309 .map_err(|_| anyhow!("connection was closed"))?;
310 Ok(())
311 });
312 async move {
313 send?;
314 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
315 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
316 Err(anyhow!("RPC request failed - {}", error.message))
317 } else {
318 T::Response::from_envelope(response)
319 .ok_or_else(|| anyhow!("received response of the wrong type"))
320 }
321 }
322 }
323
324 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
325 let connection = self.connection_state(receiver_id)?;
326 let message_id = connection
327 .next_message_id
328 .fetch_add(1, atomic::Ordering::SeqCst);
329 connection
330 .outgoing_tx
331 .unbounded_send(proto::Message::Envelope(
332 message.into_envelope(message_id, None, None),
333 ))?;
334 Ok(())
335 }
336
337 pub fn forward_send<T: EnvelopedMessage>(
338 &self,
339 sender_id: ConnectionId,
340 receiver_id: ConnectionId,
341 message: T,
342 ) -> Result<()> {
343 let connection = self.connection_state(receiver_id)?;
344 let message_id = connection
345 .next_message_id
346 .fetch_add(1, atomic::Ordering::SeqCst);
347 connection
348 .outgoing_tx
349 .unbounded_send(proto::Message::Envelope(message.into_envelope(
350 message_id,
351 None,
352 Some(sender_id.0),
353 )))?;
354 Ok(())
355 }
356
357 pub fn respond<T: RequestMessage>(
358 &self,
359 receipt: Receipt<T>,
360 response: T::Response,
361 ) -> Result<()> {
362 let connection = self.connection_state(receipt.sender_id)?;
363 let message_id = connection
364 .next_message_id
365 .fetch_add(1, atomic::Ordering::SeqCst);
366 connection
367 .outgoing_tx
368 .unbounded_send(proto::Message::Envelope(response.into_envelope(
369 message_id,
370 Some(receipt.message_id),
371 None,
372 )))?;
373 Ok(())
374 }
375
376 pub fn respond_with_error<T: RequestMessage>(
377 &self,
378 receipt: Receipt<T>,
379 response: proto::Error,
380 ) -> Result<()> {
381 let connection = self.connection_state(receipt.sender_id)?;
382 let message_id = connection
383 .next_message_id
384 .fetch_add(1, atomic::Ordering::SeqCst);
385 connection
386 .outgoing_tx
387 .unbounded_send(proto::Message::Envelope(response.into_envelope(
388 message_id,
389 Some(receipt.message_id),
390 None,
391 )))?;
392 Ok(())
393 }
394
395 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
396 let connections = self.connections.read();
397 let connection = connections
398 .get(&connection_id)
399 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
400 Ok(connection.clone())
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients talk to one server; each request gets its matching response.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        // This `_kill` shadows the first one without dropping it, so both
        // in-memory connections stay alive until the end of the test.
        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test`; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response are observed before the response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one pending request must not stall the incoming stream or
    /// other requests on the same connection.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// `disconnect` ends the I/O future and the incoming stream, and closes the
    /// underlying connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote
    /// end of the connection goes away.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}