1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use smol_timeout::TimeoutExt;
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25
/// Identifies one live connection held by a peer. Ids are handed out by an
/// atomic counter and are unique per `Peer`, not globally.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the inner integer so formatting flags (width, fill,
        // alignment) supplied by the caller are honored.
        fmt::Display::fmt(&self.0, f)
    }
}
34
/// Identifies a remote peer, as assigned by the server side of the protocol.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Forward to the wrapped integer's `Display` so caller-provided
        // formatting flags still apply.
        fmt::Display::fmt(&self.0, f)
    }
}
43
/// A handle identifying a received request of type `T`, used later to send a
/// response back to the requester via `Peer::respond`.
pub struct Receipt<T> {
    // Connection the request arrived on; the response is sent back here.
    pub sender_id: ConnectionId,
    // Id of the original request message; echoed as `responding_to`.
    pub message_id: u32,
    // Zero-sized marker tying the receipt to the request's payload type so
    // only `T::Response` can be sent for it.
    payload_type: PhantomData<T>,
}
49
50impl<T> Clone for Receipt<T> {
51 fn clone(&self) -> Self {
52 Self {
53 sender_id: self.sender_id,
54 message_id: self.message_id,
55 payload_type: PhantomData,
56 }
57 }
58}
59
60impl<T> Copy for Receipt<T> {}
61
/// A decoded protobuf envelope whose payload has been downcast to a concrete
/// message type `T`.
pub struct TypedEnvelope<T> {
    // Connection this message arrived on.
    pub sender_id: ConnectionId,
    // Peer that originally produced the message, when it was forwarded
    // through another peer (e.g. the server); `None` for direct messages.
    pub original_sender_id: Option<PeerId>,
    // Id assigned by the sending peer; used to correlate responses.
    pub message_id: u32,
    // The decoded message payload.
    pub payload: T,
}
68
69impl<T> TypedEnvelope<T> {
70 pub fn original_sender_id(&self) -> Result<PeerId> {
71 self.original_sender_id
72 .ok_or_else(|| anyhow!("missing original_sender_id"))
73 }
74}
75
76impl<T: RequestMessage> TypedEnvelope<T> {
77 pub fn receipt(&self) -> Receipt<T> {
78 Receipt {
79 sender_id: self.sender_id,
80 message_id: self.message_id,
81 payload_type: PhantomData,
82 }
83 }
84}
85
/// A participant in the RPC protocol. A peer can hold several simultaneous
/// connections, sending requests, one-way messages, and responses over each.
pub struct Peer {
    /// Live connections keyed by id. Entries are removed on disconnect and
    /// when a connection's I/O loop terminates.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Monotonic source of ids for newly added connections.
    next_connection_id: AtomicU32,
}
90
/// Cheaply-cloneable per-connection state shared between the connection's
/// I/O task and the sending methods on `Peer`.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of messages waiting to be written to the socket by the I/O task.
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    // Monotonic id generator for messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Routes responses back to pending request futures, keyed by the id of
    // the original request. The inner `Option` becomes `None` once the
    // connection has closed, so late registrations can be rejected. The
    // second sender in the tuple acts as a barrier: the I/O side waits for it
    // to drop before delivering subsequent messages.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
98
/// Send a `Ping` at this interval when no other message has been written.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
/// Abandon the connection if a single write takes longer than this.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
/// Close the connection if nothing is received for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
102
103impl Peer {
104 pub fn new() -> Arc<Self> {
105 Arc::new(Self {
106 connections: Default::default(),
107 next_connection_id: Default::default(),
108 })
109 }
110
    /// Registers `connection` with this peer and returns everything needed to
    /// drive it: the new connection's id, a future that performs all reads and
    /// writes on the underlying streams, and a stream of incoming messages.
    ///
    /// `create_timer` constructs the futures used for keepalive and receive
    /// timeouts, which lets tests supply a deterministic clock.
    ///
    /// The returned I/O future must be polled (typically spawned) for the
    /// connection to make progress. When it finishes or is dropped, the
    /// connection state is removed from `self.connections` and pending
    /// response channels are discarded.
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            // Runs when this future completes OR is dropped, so the
            // connection is always cleaned up.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // Outer loop: one iteration per incoming message read. The inner
            // loop multiplexes writes and timers while a read is in flight.
            loop {
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    // `select_biased!` prefers writes over reads, so queued
                    // outgoing messages are flushed eagerly.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
                                    result.context("failed to write RPC message")?;
                                    // Any successful write counts as a keepalive.
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                } else {
                                    Err(anyhow!("timed out writing message"))?;
                                }
                            }
                            // All senders dropped: the connection was removed; exit cleanly.
                            None => return Ok(()),
                        },
                        incoming = read_message => {
                            let incoming = incoming.context("received invalid RPC message")?;
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                // Forward to the consumer; a full channel applies
                                // backpressure, bounded by RECEIVE_TIMEOUT.
                                match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
                                    Some(Ok(_)) => {},
                                    // Receiver dropped: consumer is gone; exit cleanly.
                                    Some(Err(_)) => return Ok(()),
                                    None => Err(anyhow!("timed out processing incoming message"))?,
                                }
                            }
                            break;
                        },
                        _ = keepalive_timer => {
                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
                                result.context("failed to send keepalive")?;
                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                            } else {
                                Err(anyhow!("timed out sending keepalive"))?;
                            }
                        }
                        _ = receive_timeout => {
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Split incoming envelopes: responses are routed to the waiting
        // request future; everything else flows out as a typed envelope.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                        // Wait until the requester drops its barrier before
                        // delivering further messages, preserving ordering
                        // between a response and subsequent messages.
                        let _ = requester_resumed.1.await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else {
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        log::error!("unable to construct a typed envelope");
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
236
237 #[cfg(any(test, feature = "test-support"))]
238 pub async fn add_test_connection(
239 self: &Arc<Self>,
240 connection: Connection,
241 executor: Arc<gpui::executor::Background>,
242 ) -> (
243 ConnectionId,
244 impl Future<Output = anyhow::Result<()>> + Send,
245 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
246 ) {
247 let executor = executor.clone();
248 self.add_connection(connection, move |duration| executor.timer(duration))
249 .await
250 }
251
252 pub fn disconnect(&self, connection_id: ConnectionId) {
253 self.connections.write().remove(&connection_id);
254 }
255
256 pub fn reset(&self) {
257 self.connections.write().clear();
258 }
259
    /// Sends `request` to the given connection, returning a future that
    /// resolves with the typed response — or an error if the connection
    /// closes or the remote peer replies with `proto::Error`.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
267
    /// Like [`Peer::request`], but forwards a request on behalf of another
    /// peer: `sender_id` is recorded in the envelope as the original sender.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
276
    /// Shared implementation of [`Peer::request`] and [`Peer::forward_request`].
    ///
    /// The response channel is registered and the request is enqueued
    /// synchronously, before the returned future is first polled, so the
    /// message order observed by the remote peer matches the order in which
    /// callers invoked the send/request methods.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            // `response_channels` is `None` once the connection has closed.
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                )))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // Dropping `_barrier` (when this future completes) lets the
            // connection's incoming loop resume delivering later messages.
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }
313
314 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
315 let connection = self.connection_state(receiver_id)?;
316 let message_id = connection
317 .next_message_id
318 .fetch_add(1, atomic::Ordering::SeqCst);
319 connection
320 .outgoing_tx
321 .unbounded_send(proto::Message::Envelope(
322 message.into_envelope(message_id, None, None),
323 ))?;
324 Ok(())
325 }
326
327 pub fn forward_send<T: EnvelopedMessage>(
328 &self,
329 sender_id: ConnectionId,
330 receiver_id: ConnectionId,
331 message: T,
332 ) -> Result<()> {
333 let connection = self.connection_state(receiver_id)?;
334 let message_id = connection
335 .next_message_id
336 .fetch_add(1, atomic::Ordering::SeqCst);
337 connection
338 .outgoing_tx
339 .unbounded_send(proto::Message::Envelope(message.into_envelope(
340 message_id,
341 None,
342 Some(sender_id.0),
343 )))?;
344 Ok(())
345 }
346
347 pub fn respond<T: RequestMessage>(
348 &self,
349 receipt: Receipt<T>,
350 response: T::Response,
351 ) -> Result<()> {
352 let connection = self.connection_state(receipt.sender_id)?;
353 let message_id = connection
354 .next_message_id
355 .fetch_add(1, atomic::Ordering::SeqCst);
356 connection
357 .outgoing_tx
358 .unbounded_send(proto::Message::Envelope(response.into_envelope(
359 message_id,
360 Some(receipt.message_id),
361 None,
362 )))?;
363 Ok(())
364 }
365
366 pub fn respond_with_error<T: RequestMessage>(
367 &self,
368 receipt: Receipt<T>,
369 response: proto::Error,
370 ) -> Result<()> {
371 let connection = self.connection_state(receipt.sender_id)?;
372 let message_id = connection
373 .next_message_id
374 .fetch_add(1, atomic::Ordering::SeqCst);
375 connection
376 .outgoing_tx
377 .unbounded_send(proto::Message::Envelope(response.into_envelope(
378 message_id,
379 Some(receipt.message_id),
380 None,
381 )))?;
382 Ok(())
383 }
384
385 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
386 let connections = self.connections.read();
387 let connection = connections
388 .get(&connection_id)
389 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
390 Ok(connection.clone())
391 }
392}
393
394#[cfg(test)]
395mod tests {
396 use super::*;
397 use crate::TypedEnvelope;
398 use async_tungstenite::tungstenite::Message as WebSocketMessage;
399 use gpui::TestAppContext;
400
401 #[gpui::test(iterations = 50)]
402 async fn test_request_response(cx: &mut TestAppContext) {
403 let executor = cx.foreground();
404
405 // create 2 clients connected to 1 server
406 let server = Peer::new();
407 let client1 = Peer::new();
408 let client2 = Peer::new();
409
410 let (client1_to_server_conn, server_to_client_1_conn, _kill) =
411 Connection::in_memory(cx.background());
412 let (client1_conn_id, io_task1, client1_incoming) = client1
413 .add_test_connection(client1_to_server_conn, cx.background())
414 .await;
415 let (_, io_task2, server_incoming1) = server
416 .add_test_connection(server_to_client_1_conn, cx.background())
417 .await;
418
419 let (client2_to_server_conn, server_to_client_2_conn, _kill) =
420 Connection::in_memory(cx.background());
421 let (client2_conn_id, io_task3, client2_incoming) = client2
422 .add_test_connection(client2_to_server_conn, cx.background())
423 .await;
424 let (_, io_task4, server_incoming2) = server
425 .add_test_connection(server_to_client_2_conn, cx.background())
426 .await;
427
428 executor.spawn(io_task1).detach();
429 executor.spawn(io_task2).detach();
430 executor.spawn(io_task3).detach();
431 executor.spawn(io_task4).detach();
432 executor
433 .spawn(handle_messages(server_incoming1, server.clone()))
434 .detach();
435 executor
436 .spawn(handle_messages(client1_incoming, client1.clone()))
437 .detach();
438 executor
439 .spawn(handle_messages(server_incoming2, server.clone()))
440 .detach();
441 executor
442 .spawn(handle_messages(client2_incoming, client2.clone()))
443 .detach();
444
445 assert_eq!(
446 client1
447 .request(client1_conn_id, proto::Ping {},)
448 .await
449 .unwrap(),
450 proto::Ack {}
451 );
452
453 assert_eq!(
454 client2
455 .request(client2_conn_id, proto::Ping {},)
456 .await
457 .unwrap(),
458 proto::Ack {}
459 );
460
461 assert_eq!(
462 client1
463 .request(client1_conn_id, proto::Test { id: 1 },)
464 .await
465 .unwrap(),
466 proto::Test { id: 1 }
467 );
468
469 assert_eq!(
470 client2
471 .request(client2_conn_id, proto::Test { id: 2 })
472 .await
473 .unwrap(),
474 proto::Test { id: 2 }
475 );
476
477 client1.disconnect(client1_conn_id);
478 client2.disconnect(client1_conn_id);
479
480 async fn handle_messages(
481 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
482 peer: Arc<Peer>,
483 ) -> Result<()> {
484 while let Some(envelope) = messages.next().await {
485 let envelope = envelope.into_any();
486 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
487 let receipt = envelope.receipt();
488 peer.respond(receipt, proto::Ack {})?
489 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
490 {
491 peer.respond(envelope.receipt(), envelope.payload.clone())?
492 } else {
493 panic!("unknown message type");
494 }
495 }
496
497 Ok(())
498 }
499 }
500
    /// Verifies ordering: messages the server sends *before* its response
    /// must appear on the client's incoming stream before the client's
    /// request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive the ping, emit two errors, then respond.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the observed order of events across tasks.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        // Client side: drain the two error messages, recording their order.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
600
    /// Verifies that dropping one request future mid-flight does not disturb
    /// another pending request or the ordering of incoming messages.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both pings, emit two errors, respond to both.
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the observed order of events across tasks.
        let events = Arc::new(Mutex::new(Vec::new()));

        // The first request will be dropped before it completes; only the
        // second should record a response.
        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        // Client side: drain the two error messages, recording their order.
        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
714
    /// Verifies that `disconnect` ends the I/O task, terminates the incoming
    /// message stream, and closes the underlying connection for the far end.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        // Signal when the I/O future finishes.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signal when the incoming stream yields (here: ends).
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        // The far end should now observe the connection as closed.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
751
    /// Verifies that a pending request fails with "connection was closed"
    /// when the remote end of the connection is dropped mid-request.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Ensure the request reached the wire before severing the connection.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
775}