1use super::{
2 proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage},
3 Connection,
4};
5use anyhow::{anyhow, Context, Result};
6use collections::HashMap;
7use futures::{
8 channel::{mpsc, oneshot},
9 stream::BoxStream,
10 FutureExt, SinkExt, StreamExt,
11};
12use parking_lot::{Mutex, RwLock};
13use smol_timeout::TimeoutExt;
14use std::sync::atomic::Ordering::SeqCst;
15use std::{
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23 time::Duration,
24};
25
/// Identifier for a single connection held by a `Peer`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the inner u32's Display so formatter flags
        // (width, fill, alignment) are honored.
        fmt::Display::fmt(&self.0, f)
    }
}
34
/// Identifier for a remote peer, distinct from the local connection id.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to the inner u32's Display so formatter flags are honored.
        fmt::Display::fmt(&self.0, f)
    }
}
43
/// Handle identifying a received request, used to send a response back to
/// the requester via `Peer::respond` / `Peer::respond_with_error`.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    // Ties the receipt to the request's message type without storing a value,
    // so a response can only be sent with the matching `T::Response` type.
    payload_type: PhantomData<T>,
}
49
50impl<T> Clone for Receipt<T> {
51 fn clone(&self) -> Self {
52 Self {
53 sender_id: self.sender_id,
54 message_id: self.message_id,
55 payload_type: PhantomData,
56 }
57 }
58}
59
60impl<T> Copy for Receipt<T> {}
61
/// A decoded incoming message together with its routing metadata.
pub struct TypedEnvelope<T> {
    /// The connection this message arrived on.
    pub sender_id: ConnectionId,
    /// The peer the message was originally sent by, when it was forwarded
    /// on another peer's behalf.
    pub original_sender_id: Option<PeerId>,
    /// Per-connection message id assigned by the sender.
    pub message_id: u32,
    pub payload: T,
}
68
69impl<T> TypedEnvelope<T> {
70 pub fn original_sender_id(&self) -> Result<PeerId> {
71 self.original_sender_id
72 .ok_or_else(|| anyhow!("missing original_sender_id"))
73 }
74}
75
76impl<T: RequestMessage> TypedEnvelope<T> {
77 pub fn receipt(&self) -> Receipt<T> {
78 Receipt {
79 sender_id: self.sender_id,
80 message_id: self.message_id,
81 payload_type: PhantomData,
82 }
83 }
84}
85
/// An endpoint of the RPC protocol: tracks all live connections and routes
/// outgoing messages, requests, and responses to them.
pub struct Peer {
    /// State for every live connection, keyed by its id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Monotonically increasing source of fresh `ConnectionId`s.
    next_connection_id: AtomicU32,
}
90
/// Per-connection bookkeeping, shared between the `Peer` and the
/// connection's I/O task. Cheap to clone: all fields are handles.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of messages waiting to be written to the connection's socket.
    outgoing_tx: mpsc::UnboundedSender<proto::Message>,
    // Monotonically increasing source of message ids for this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending-request channels keyed by message id. Set to `None` when the
    // connection closes, failing all outstanding requests. The nested
    // `oneshot::Sender<()>` is a barrier: the message loop waits on it until
    // the requester has resumed, preserving response/message ordering.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, oneshot::Sender<()>)>>>>>,
}
98
// Send a ping if nothing has been written for this long, so the connection
// isn't considered dead by the other side.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
// Give up on any single write that takes longer than this.
const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
// Drop the connection if nothing is received for this long.
pub const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5);
102
103impl Peer {
104 pub fn new() -> Arc<Self> {
105 Arc::new(Self {
106 connections: Default::default(),
107 next_connection_id: Default::default(),
108 })
109 }
110
    /// Registers `connection` with this peer and returns:
    /// - the new connection's id,
    /// - a future that drives the connection's reads, writes, and keepalives
    ///   (it must be polled for the connection to make progress), and
    /// - a stream of incoming, typed messages. Responses to pending requests
    ///   are routed to their requesters and filtered out of this stream.
    ///
    /// `create_timer` abstracts the runtime's timer so tests can drive time
    /// with a deterministic executor.
    pub async fn add_connection<F, Fut, Out>(
        self: &Arc<Self>,
        connection: Connection,
        create_timer: F,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    )
    where
        F: Send + Fn(Duration) -> Fut,
        Fut: Send + Future<Output = Out>,
        Out: Send,
    {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        #[cfg(any(test, feature = "test-support"))]
        const INCOMING_BUFFER_SIZE: usize = 1;
        #[cfg(not(any(test, feature = "test-support")))]
        const INCOMING_BUFFER_SIZE: usize = 64;
        let (mut incoming_tx, incoming_rx) = mpsc::channel(INCOMING_BUFFER_SIZE);
        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx: outgoing_tx.clone(),
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            // On any exit path: drop all pending response channels (failing
            // their requests) and deregister the connection.
            let _end_connection = util::defer(|| {
                response_channels.lock().take();
                this.connections.write().remove(&connection_id);
            });

            // Send messages on this frequency so the connection isn't closed.
            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
            futures::pin_mut!(keepalive_timer);

            // Disconnect if we don't receive messages at least this frequently.
            let receive_timeout = create_timer(RECEIVE_TIMEOUT).fuse();
            futures::pin_mut!(receive_timeout);

            // Outer loop: create a fresh read future; inner loop: service
            // writes and timers while that read is pending.
            loop {
                let read_message = reader.read().fuse();
                futures::pin_mut!(read_message);

                loop {
                    // `select_biased!`: writes take priority over reads,
                    // which take priority over the timers.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
                                    result.context("failed to write RPC message")?;
                                    // Any successful write also serves as a keepalive.
                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                                } else {
                                    Err(anyhow!("timed out writing message"))?;
                                }
                            }
                            // All senders dropped: the connection was removed; finish cleanly.
                            None => return Ok(()),
                        },
                        incoming = read_message => {
                            let incoming = incoming.context("received invalid RPC message")?;
                            // Receiving anything (even a ping) resets the receive timeout.
                            receive_timeout.set(create_timer(RECEIVE_TIMEOUT).fuse());
                            if let proto::Message::Envelope(incoming) = incoming {
                                // Forward into the bounded channel, applying backpressure.
                                // A dropped receiver means the consumer is gone: end cleanly.
                                match incoming_tx.send(incoming).timeout(RECEIVE_TIMEOUT).await {
                                    Some(Ok(_)) => {},
                                    Some(Err(_)) => return Ok(()),
                                    None => Err(anyhow!("timed out processing incoming message"))?,
                                }
                            }
                            // Restart the outer loop to build a fresh read future.
                            break;
                        },
                        _ = keepalive_timer => {
                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
                                result.context("failed to send keepalive")?;
                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
                            } else {
                                Err(anyhow!("timed out sending keepalive"))?;
                            }
                        }
                        _ = receive_timeout => {
                            Err(anyhow!("delay between messages too long"))?
                        }
                    }
                }
            }
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Route response envelopes to their waiting requests; everything else
        // is surfaced to the caller as a typed envelope.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    // `as_mut()?` short-circuits to `None` once the connection
                    // has closed and the channel map was taken.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(tx) = channel {
                        let requester_resumed = oneshot::channel();
                        if let Err(error) = tx.send((incoming, requester_resumed.0)) {
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                        // Block this stream until the requester resumes (or its
                        // barrier is dropped), so a response is observed before
                        // any messages that arrived after it.
                        let _ = requester_resumed.1.await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else {
                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                        log::error!("unable to construct a typed envelope");
                        None
                    })
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }
240
241 #[cfg(any(test, feature = "test-support"))]
242 pub async fn add_test_connection(
243 self: &Arc<Self>,
244 connection: Connection,
245 executor: Arc<gpui::executor::Background>,
246 ) -> (
247 ConnectionId,
248 impl Future<Output = anyhow::Result<()>> + Send,
249 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
250 ) {
251 let executor = executor.clone();
252 self.add_connection(connection, move |duration| executor.timer(duration))
253 .await
254 }
255
256 pub fn disconnect(&self, connection_id: ConnectionId) {
257 self.connections.write().remove(&connection_id);
258 }
259
260 pub fn reset(&self) {
261 self.connections.write().clear();
262 }
263
    /// Sends `request` to `receiver_id` and resolves with the typed response.
    ///
    /// The request is enqueued synchronously when this method is called;
    /// only the response is awaited by the returned future.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
271
    /// Like [`Peer::request`], but records `sender_id` as the original
    /// sender so the recipient can tell on whose behalf it was forwarded.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
280
    /// Shared implementation of [`Peer::request`] and [`Peer::forward_request`].
    ///
    /// The response channel is registered and the request enqueued *before*
    /// the future is returned (not when it is first polled), so a fast
    /// response cannot be missed. The channel must be inserted before the
    /// send so the I/O task can always find it when the response arrives.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, rx) = oneshot::channel();
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                // `None` means the connection already closed.
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                )))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // Dropping `_barrier` (when this future resumes or is dropped)
            // releases the connection's message loop, preserving ordering
            // between this response and subsequently received messages.
            let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("RPC request failed - {}", error.message))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }
317
318 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
319 let connection = self.connection_state(receiver_id)?;
320 let message_id = connection
321 .next_message_id
322 .fetch_add(1, atomic::Ordering::SeqCst);
323 connection
324 .outgoing_tx
325 .unbounded_send(proto::Message::Envelope(
326 message.into_envelope(message_id, None, None),
327 ))?;
328 Ok(())
329 }
330
331 pub fn forward_send<T: EnvelopedMessage>(
332 &self,
333 sender_id: ConnectionId,
334 receiver_id: ConnectionId,
335 message: T,
336 ) -> Result<()> {
337 let connection = self.connection_state(receiver_id)?;
338 let message_id = connection
339 .next_message_id
340 .fetch_add(1, atomic::Ordering::SeqCst);
341 connection
342 .outgoing_tx
343 .unbounded_send(proto::Message::Envelope(message.into_envelope(
344 message_id,
345 None,
346 Some(sender_id.0),
347 )))?;
348 Ok(())
349 }
350
351 pub fn respond<T: RequestMessage>(
352 &self,
353 receipt: Receipt<T>,
354 response: T::Response,
355 ) -> Result<()> {
356 let connection = self.connection_state(receipt.sender_id)?;
357 let message_id = connection
358 .next_message_id
359 .fetch_add(1, atomic::Ordering::SeqCst);
360 connection
361 .outgoing_tx
362 .unbounded_send(proto::Message::Envelope(response.into_envelope(
363 message_id,
364 Some(receipt.message_id),
365 None,
366 )))?;
367 Ok(())
368 }
369
370 pub fn respond_with_error<T: RequestMessage>(
371 &self,
372 receipt: Receipt<T>,
373 response: proto::Error,
374 ) -> Result<()> {
375 let connection = self.connection_state(receipt.sender_id)?;
376 let message_id = connection
377 .next_message_id
378 .fetch_add(1, atomic::Ordering::SeqCst);
379 connection
380 .outgoing_tx
381 .unbounded_send(proto::Message::Envelope(response.into_envelope(
382 message_id,
383 Some(receipt.message_id),
384 None,
385 )))?;
386 Ok(())
387 }
388
389 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
390 let connections = self.connections.read();
391 let connection = connections
392 .get(&connection_id)
393 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
394 Ok(connection.clone())
395 }
396}
397
398#[cfg(test)]
399mod tests {
400 use super::*;
401 use crate::TypedEnvelope;
402 use async_tungstenite::tungstenite::Message as WebSocketMessage;
403 use gpui::TestAppContext;
404
    /// Two clients connected to one server: requests from each client are
    /// answered independently and resolve to the expected typed responses.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        // Drive all four connections' I/O in the background.
        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        // Attach an echo-style handler to every incoming stream.
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client1_conn_id);

        /// Answers every request on `messages`: `Ping` with `Ack`, and
        /// `Test` by echoing the payload back.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }
504
    /// A response must be observed by the requester *after* any messages the
    /// server sent before responding: here the two error messages are pushed
    /// into `events` before "response", in send order.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: send two messages *before* responding to the request.
        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Record the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both pre-response messages must be seen before the response.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
604
    /// Dropping one in-flight request future must not stall the connection:
    /// a second concurrent request still completes, and message/response
    /// ordering is preserved for it.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        // Server side: receive both pings, interleave two messages, then
        // respond to both (including the one whose future gets dropped).
        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // Only request 2's response is recorded; request 1 was dropped.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
718
    /// `disconnect` must terminate both the I/O task and the incoming message
    /// stream, and the remote side's subsequent sends must fail.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        // Signals when the I/O task has finished.
        let (io_ended_tx, io_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).unwrap();
            })
            .detach();

        // Signals when the incoming message stream has ended.
        let (messages_ended_tx, messages_ended_rx) = oneshot::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        let _ = io_ended_rx.await;
        let _ = messages_ended_rx.await;
        // The other end of the connection should now be unusable.
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }
755
    /// When the underlying connection dies mid-request, the pending request
    /// future must fail with a "connection was closed" error rather than hang.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        // Wait until the request reaches the server, then kill the connection.
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
779}