1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use postage::{
7 barrier, mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifies one connection held by a [`Peer`]. Ids are handed out
/// sequentially per peer (see `Peer::next_connection_id`), so they are unique
/// within a single `Peer`, not globally.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
26
27impl fmt::Display for ConnectionId {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 self.0.fmt(f)
30 }
31}
32
/// Identifies a remote peer in forwarded messages (see
/// `TypedEnvelope::original_sender_id`). Distinct from [`ConnectionId`],
/// which is assigned locally by each `Peer`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
35
36impl fmt::Display for PeerId {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 self.0.fmt(f)
39 }
40}
41
/// A handle to a received request, allowing a response to be sent later via
/// `Peer::respond`. The zero-sized `payload_type` marker ties the receipt to
/// the request type `T` so that only a `T::Response` can be sent back.
pub struct Receipt<T> {
    // The connection the request arrived on — where the response must go.
    pub sender_id: ConnectionId,
    // The sender-assigned id of the request, echoed back so the requester can
    // match the response (see `Peer::respond`).
    pub message_id: u32,
    payload_type: PhantomData<T>,
}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
/// An incoming message whose payload has been deserialized to the concrete
/// type `T`, along with routing metadata from the wire envelope.
pub struct TypedEnvelope<T> {
    // The connection this message arrived on.
    pub sender_id: ConnectionId,
    // Set when the message was forwarded on behalf of another peer.
    pub original_sender_id: Option<PeerId>,
    // The sender-assigned id of this message.
    pub message_id: u32,
    pub payload: T,
}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
/// Manages a set of RPC connections, multiplexing request/response pairs and
/// fire-and-forget messages over each one.
pub struct Peer {
    /// Active connections, keyed by their locally-assigned id.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of sequential ids for newly added connections.
    next_connection_id: AtomicU32,
}
88
/// Shared per-connection state. Cheap to clone: every field is a channel
/// handle or an `Arc`.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of envelopes to be written by this connection's I/O task.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Monotonic id for the next outgoing message on this connection.
    next_message_id: Arc<AtomicU32>,
    // Pending request response senders, keyed by the request's message id.
    // Becomes `None` once the connection closes (taken during teardown in
    // `Peer::add_connection`), which makes in-flight and future requests fail
    // with "connection was closed".
    response_channels:
        Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
96
97const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
98
99impl Peer {
100 pub fn new() -> Arc<Self> {
101 Arc::new(Self {
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub async fn add_connection<F, Fut, Out>(
108 self: &Arc<Self>,
109 connection: Connection,
110 create_timer: F,
111 ) -> (
112 ConnectionId,
113 impl Future<Output = anyhow::Result<()>> + Send,
114 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115 )
116 where
117 F: Send + Fn(Duration) -> Fut,
118 Fut: Send + Future<Output = Out>,
119 Out: Send,
120 {
121 // For outgoing messages, use an unbounded channel so that application code
122 // can always send messages without yielding. For incoming messages, use a
123 // bounded channel so that other peers will receive backpressure if they send
124 // messages faster than this peer can process them.
125 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
126 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
127
128 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
129 let connection_state = ConnectionState {
130 outgoing_tx: outgoing_tx.clone(),
131 next_message_id: Default::default(),
132 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
133 };
134 let mut writer = MessageStream::new(connection.tx);
135 let mut reader = MessageStream::new(connection.rx);
136
137 let this = self.clone();
138 let response_channels = connection_state.response_channels.clone();
139 let handle_io = async move {
140 let _end_connection = util::defer(|| {
141 response_channels.lock().take();
142 this.connections.write().remove(&connection_id);
143 });
144
145 loop {
146 let read_message = reader.read().fuse();
147 futures::pin_mut!(read_message);
148 let read_timeout = create_timer(2 * KEEPALIVE_INTERVAL).fuse();
149 futures::pin_mut!(read_timeout);
150
151 loop {
152 futures::select_biased! {
153 outgoing = outgoing_rx.next().fuse() => match outgoing {
154 Some(outgoing) => {
155 let outgoing = proto::Message::Envelope(outgoing);
156 if let Some(result) = writer.write(outgoing).timeout(2 * KEEPALIVE_INTERVAL).await {
157 result.context("failed to write RPC message")?;
158 } else {
159 Err(anyhow!("timed out writing message"))?;
160 }
161 }
162 None => return Ok(()),
163 },
164 incoming = read_message => {
165 let incoming = incoming.context("received invalid RPC message")?;
166 if let proto::Message::Envelope(incoming) = incoming {
167 if incoming_tx.send(incoming).await.is_err() {
168 return Ok(());
169 }
170 }
171
172 break;
173 },
174 _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
175 if let Some(result) = writer.write(proto::Message::Ping).timeout(2 * KEEPALIVE_INTERVAL).await {
176 result.context("failed to send websocket ping")?;
177 } else {
178 Err(anyhow!("timed out sending websocket ping"))?;
179 }
180 }
181 _ = read_timeout => {
182 Err(anyhow!("timed out reading message"))?
183 }
184 }
185 }
186 }
187 };
188
189 let response_channels = connection_state.response_channels.clone();
190 self.connections
191 .write()
192 .insert(connection_id, connection_state);
193
194 let incoming_rx = incoming_rx.filter_map(move |incoming| {
195 let response_channels = response_channels.clone();
196 async move {
197 if let Some(responding_to) = incoming.responding_to {
198 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
199 if let Some(tx) = channel {
200 let mut requester_resumed = barrier::channel();
201 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
202 log::debug!(
203 "received RPC but request future was dropped {:?}",
204 error.0
205 );
206 }
207 requester_resumed.1.recv().await;
208 } else {
209 log::warn!("received RPC response to unknown request {}", responding_to);
210 }
211
212 None
213 } else {
214 proto::build_typed_envelope(connection_id, incoming).or_else(|| {
215 log::error!("unable to construct a typed envelope");
216 None
217 })
218 }
219 }
220 });
221 (connection_id, handle_io, incoming_rx.boxed())
222 }
223
224 #[cfg(any(test, feature = "test-support"))]
225 pub async fn add_test_connection(
226 self: &Arc<Self>,
227 connection: Connection,
228 executor: Arc<gpui::executor::Background>,
229 ) -> (
230 ConnectionId,
231 impl Future<Output = anyhow::Result<()>> + Send,
232 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
233 ) {
234 let executor = executor.clone();
235 self.add_connection(connection, move |duration| executor.timer(duration))
236 .await
237 }
238
239 pub fn disconnect(&self, connection_id: ConnectionId) {
240 self.connections.write().remove(&connection_id);
241 }
242
243 pub fn reset(&self) {
244 self.connections.write().clear();
245 }
246
247 pub fn request<T: RequestMessage>(
248 &self,
249 receiver_id: ConnectionId,
250 request: T,
251 ) -> impl Future<Output = Result<T::Response>> {
252 self.request_internal(None, receiver_id, request)
253 }
254
255 pub fn forward_request<T: RequestMessage>(
256 &self,
257 sender_id: ConnectionId,
258 receiver_id: ConnectionId,
259 request: T,
260 ) -> impl Future<Output = Result<T::Response>> {
261 self.request_internal(Some(sender_id), receiver_id, request)
262 }
263
264 pub fn request_internal<T: RequestMessage>(
265 &self,
266 original_sender_id: Option<ConnectionId>,
267 receiver_id: ConnectionId,
268 request: T,
269 ) -> impl Future<Output = Result<T::Response>> {
270 let (tx, rx) = oneshot::channel();
271 let send = self.connection_state(receiver_id).and_then(|connection| {
272 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
273 connection
274 .response_channels
275 .lock()
276 .as_mut()
277 .ok_or_else(|| anyhow!("connection was closed"))?
278 .insert(message_id, tx);
279 connection
280 .outgoing_tx
281 .unbounded_send(request.into_envelope(
282 message_id,
283 None,
284 original_sender_id.map(|id| id.0),
285 ))
286 .map_err(|_| anyhow!("connection was closed"))?;
287 Ok(())
288 });
289 async move {
290 send?;
291 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
292 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
293 Err(anyhow!("RPC request failed - {}", error.message))
294 } else {
295 T::Response::from_envelope(response)
296 .ok_or_else(|| anyhow!("received response of the wrong type"))
297 }
298 }
299 }
300
301 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
302 let connection = self.connection_state(receiver_id)?;
303 let message_id = connection
304 .next_message_id
305 .fetch_add(1, atomic::Ordering::SeqCst);
306 connection
307 .outgoing_tx
308 .unbounded_send(message.into_envelope(message_id, None, None))?;
309 Ok(())
310 }
311
312 pub fn forward_send<T: EnvelopedMessage>(
313 &self,
314 sender_id: ConnectionId,
315 receiver_id: ConnectionId,
316 message: T,
317 ) -> Result<()> {
318 let connection = self.connection_state(receiver_id)?;
319 let message_id = connection
320 .next_message_id
321 .fetch_add(1, atomic::Ordering::SeqCst);
322 connection
323 .outgoing_tx
324 .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
325 Ok(())
326 }
327
328 pub fn respond<T: RequestMessage>(
329 &self,
330 receipt: Receipt<T>,
331 response: T::Response,
332 ) -> Result<()> {
333 let connection = self.connection_state(receipt.sender_id)?;
334 let message_id = connection
335 .next_message_id
336 .fetch_add(1, atomic::Ordering::SeqCst);
337 connection
338 .outgoing_tx
339 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
340 Ok(())
341 }
342
343 pub fn respond_with_error<T: RequestMessage>(
344 &self,
345 receipt: Receipt<T>,
346 response: proto::Error,
347 ) -> Result<()> {
348 let connection = self.connection_state(receipt.sender_id)?;
349 let message_id = connection
350 .next_message_id
351 .fetch_add(1, atomic::Ordering::SeqCst);
352 connection
353 .outgoing_tx
354 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
355 Ok(())
356 }
357
358 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
359 let connections = self.connections.read();
360 let connection = connections
361 .get(&connection_id)
362 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
363 Ok(connection.clone())
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Round-trips `Ping` and `Test` requests between two clients and a
    /// server, asserting each client receives the expected typed response.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) = client1
            .add_test_connection(client1_to_server_conn, cx.background())
            .await;
        let (_, io_task2, server_incoming1) = server
            .add_test_connection(server_to_client_1_conn, cx.background())
            .await;

        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) = client2
            .add_test_connection(client2_to_server_conn, cx.background())
            .await;
        let (_, io_task4, server_incoming2) = server
            .add_test_connection(server_to_client_2_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Fixed: each client disconnects its own connection (this previously
        // passed `client1_conn_id`, which only worked because both peers
        // independently assign ids starting from zero).
        client2.disconnect(client2_conn_id);

        // Echo server/client: replies `Ack` to `Ping` and mirrors `Test`
        // payloads back to the requester.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Verifies that a response is delivered only after all messages sent
    /// before it have been processed by the requester.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Verifies that dropping an in-flight request future does not disturb the
    /// ordering or delivery of other messages and responses.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _kill) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
            .add_test_connection(client_to_server_conn, cx.background())
            .await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
            .add_test_connection(server_to_client_conn, cx.background())
            .await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Verifies that `disconnect` ends the I/O future, closes the incoming
    /// message stream, and makes the remote side's writes fail.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// Verifies that an in-flight request fails with "connection was closed"
    /// when the underlying connection is torn down by the remote side.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client
            .add_test_connection(client_conn, cx.background())
            .await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}