1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use postage::{
7 barrier, mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifier of a single connection registered with a `Peer`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Renders the bare numeric id, forwarding any width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ConnectionId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
32
/// Identifier of a peer, as carried in the `original_sender_id` of a forwarded
/// message.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Renders the bare numeric id, forwarding any width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let PeerId(raw) = self;
        fmt::Display::fmt(raw, f)
    }
}
41
42pub struct Receipt<T> {
43 pub sender_id: ConnectionId,
44 pub message_id: u32,
45 payload_type: PhantomData<T>,
46}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
60pub struct TypedEnvelope<T> {
61 pub sender_id: ConnectionId,
62 pub original_sender_id: Option<PeerId>,
63 pub message_id: u32,
64 pub payload: T,
65}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
84pub struct Peer {
85 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
86 next_connection_id: AtomicU32,
87}
88
89#[derive(Clone)]
90pub struct ConnectionState {
91 outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
92 next_message_id: Arc<AtomicU32>,
93 response_channels:
94 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
95}
96
/// Idle period after which the I/O loop sends a websocket ping.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(2_000);
/// Maximum time a single websocket write (message or ping) may take.
const WRITE_TIMEOUT: Duration = Duration::from_millis(10_000);
99
100impl Peer {
101 pub fn new() -> Arc<Self> {
102 Arc::new(Self {
103 connections: Default::default(),
104 next_connection_id: Default::default(),
105 })
106 }
107
108 pub async fn add_connection<F, Fut, Out>(
109 self: &Arc<Self>,
110 connection: Connection,
111 create_timer: F,
112 ) -> (
113 ConnectionId,
114 impl Future<Output = anyhow::Result<()>> + Send,
115 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
116 )
117 where
118 F: Send + Fn(Duration) -> Fut,
119 Fut: Send + Future<Output = Out>,
120 Out: Send,
121 {
122 // For outgoing messages, use an unbounded channel so that application code
123 // can always send messages without yielding. For incoming messages, use a
124 // bounded channel so that other peers will receive backpressure if they send
125 // messages faster than this peer can process them.
126 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
127 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
128
129 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
130 let connection_state = ConnectionState {
131 outgoing_tx: outgoing_tx.clone(),
132 next_message_id: Default::default(),
133 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
134 };
135 let mut writer = MessageStream::new(connection.tx);
136 let mut reader = MessageStream::new(connection.rx);
137
138 let this = self.clone();
139 let response_channels = connection_state.response_channels.clone();
140 let handle_io = async move {
141 let _end_connection = util::defer(|| {
142 response_channels.lock().take();
143 this.connections.write().remove(&connection_id);
144 });
145
146 loop {
147 let read_message = reader.read_message().fuse();
148 futures::pin_mut!(read_message);
149 loop {
150 futures::select_biased! {
151 outgoing = outgoing_rx.next().fuse() => match outgoing {
152 Some(outgoing) => {
153 if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
154 result.context("failed to write RPC message")?;
155 } else {
156 Err(anyhow!("timed out writing message"))?;
157 }
158 }
159 None => return Ok(()),
160 },
161 incoming = read_message => {
162 let incoming = incoming.context("received invalid rpc message")?;
163 if incoming_tx.send(incoming).await.is_err() {
164 return Ok(());
165 }
166 break;
167 },
168 _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
169 if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
170 result.context("failed to send websocket ping")?;
171 } else {
172 Err(anyhow!("timed out sending websocket ping"))?;
173 }
174 }
175 }
176 }
177 }
178 };
179
180 let response_channels = connection_state.response_channels.clone();
181 self.connections
182 .write()
183 .insert(connection_id, connection_state);
184
185 let incoming_rx = incoming_rx.filter_map(move |incoming| {
186 let response_channels = response_channels.clone();
187 async move {
188 if let Some(responding_to) = incoming.responding_to {
189 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
190 if let Some(tx) = channel {
191 let mut requester_resumed = barrier::channel();
192 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
193 log::debug!(
194 "received RPC but request future was dropped {:?}",
195 error.0
196 );
197 }
198 requester_resumed.1.recv().await;
199 } else {
200 log::warn!("received RPC response to unknown request {}", responding_to);
201 }
202
203 None
204 } else {
205 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
206 log::error!("unable to construct a typed envelope");
207 None
208 })
209 }
210 }
211 });
212 (connection_id, handle_io, incoming_rx.boxed())
213 }
214
215 #[cfg(any(test, feature = "test-support"))]
216 pub async fn add_test_connection(
217 self: &Arc<Self>,
218 connection: Connection,
219 executor: Arc<gpui::executor::Background>,
220 ) -> (
221 ConnectionId,
222 impl Future<Output = anyhow::Result<()>> + Send,
223 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
224 ) {
225 let executor = executor.clone();
226 self.add_connection(connection, move |duration| executor.timer(duration))
227 .await
228 }
229
230 pub fn disconnect(&self, connection_id: ConnectionId) {
231 self.connections.write().remove(&connection_id);
232 }
233
234 pub fn reset(&self) {
235 self.connections.write().clear();
236 }
237
238 pub fn request<T: RequestMessage>(
239 &self,
240 receiver_id: ConnectionId,
241 request: T,
242 ) -> impl Future<Output = Result<T::Response>> {
243 self.request_internal(None, receiver_id, request)
244 }
245
246 pub fn forward_request<T: RequestMessage>(
247 &self,
248 sender_id: ConnectionId,
249 receiver_id: ConnectionId,
250 request: T,
251 ) -> impl Future<Output = Result<T::Response>> {
252 self.request_internal(Some(sender_id), receiver_id, request)
253 }
254
255 pub fn request_internal<T: RequestMessage>(
256 &self,
257 original_sender_id: Option<ConnectionId>,
258 receiver_id: ConnectionId,
259 request: T,
260 ) -> impl Future<Output = Result<T::Response>> {
261 let (tx, rx) = oneshot::channel();
262 let send = self.connection_state(receiver_id).and_then(|connection| {
263 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
264 connection
265 .response_channels
266 .lock()
267 .as_mut()
268 .ok_or_else(|| anyhow!("connection was closed"))?
269 .insert(message_id, tx);
270 connection
271 .outgoing_tx
272 .unbounded_send(request.into_envelope(
273 message_id,
274 None,
275 original_sender_id.map(|id| id.0),
276 ))
277 .map_err(|_| anyhow!("connection was closed"))?;
278 Ok(())
279 });
280 async move {
281 send?;
282 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
283 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
284 Err(anyhow!("RPC request failed - {}", error.message))
285 } else {
286 T::Response::from_envelope(response)
287 .ok_or_else(|| anyhow!("received response of the wrong type"))
288 }
289 }
290 }
291
292 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
293 let connection = self.connection_state(receiver_id)?;
294 let message_id = connection
295 .next_message_id
296 .fetch_add(1, atomic::Ordering::SeqCst);
297 connection
298 .outgoing_tx
299 .unbounded_send(message.into_envelope(message_id, None, None))?;
300 Ok(())
301 }
302
303 pub fn forward_send<T: EnvelopedMessage>(
304 &self,
305 sender_id: ConnectionId,
306 receiver_id: ConnectionId,
307 message: T,
308 ) -> Result<()> {
309 let connection = self.connection_state(receiver_id)?;
310 let message_id = connection
311 .next_message_id
312 .fetch_add(1, atomic::Ordering::SeqCst);
313 connection
314 .outgoing_tx
315 .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
316 Ok(())
317 }
318
319 pub fn respond<T: RequestMessage>(
320 &self,
321 receipt: Receipt<T>,
322 response: T::Response,
323 ) -> Result<()> {
324 let connection = self.connection_state(receipt.sender_id)?;
325 let message_id = connection
326 .next_message_id
327 .fetch_add(1, atomic::Ordering::SeqCst);
328 connection
329 .outgoing_tx
330 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
331 Ok(())
332 }
333
334 pub fn respond_with_error<T: RequestMessage>(
335 &self,
336 receipt: Receipt<T>,
337 response: proto::Error,
338 ) -> Result<()> {
339 let connection = self.connection_state(receipt.sender_id)?;
340 let message_id = connection
341 .next_message_id
342 .fetch_add(1, atomic::Ordering::SeqCst);
343 connection
344 .outgoing_tx
345 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
346 Ok(())
347 }
348
349 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
350 let connections = self.connections.read();
351 let connection = connections
352 .get(&connection_id)
353 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
354 Ok(connection.clone())
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server each make requests; every response must
    /// come back to the client that asked, with the expected payload.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill1) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill2) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client tears down its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        // Answers every `Ping` with an `Ack` and echoes every `Test` back.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages the server sends before responding must be observed by the client
    /// before the request future completes.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not stall the incoming stream or a
    /// second request on the same connection.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting ends both the I/O future and the incoming stream, and closes
    /// the underlying connection.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the remote end of
    /// the connection goes away.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}