1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 barrier, mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Identifier of a single connection registered with a [`Peer`].
///
/// The wrapped `u32` is public; `Display` prints it exactly like the bare number.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Delegates to the inner `u32`, so formatter flags (width, fill, ...) are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
33
/// Numeric identifier of a peer.
///
/// In this file it is used to record the original sender of a forwarded message
/// (see `TypedEnvelope::original_sender_id`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Delegates to the inner `u32`, so formatter flags (width, fill, ...) are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
42
43pub struct Receipt<T> {
44 pub sender_id: ConnectionId,
45 pub message_id: u32,
46 payload_type: PhantomData<T>,
47}
48
49impl<T> Clone for Receipt<T> {
50 fn clone(&self) -> Self {
51 Self {
52 sender_id: self.sender_id,
53 message_id: self.message_id,
54 payload_type: PhantomData,
55 }
56 }
57}
58
59impl<T> Copy for Receipt<T> {}
60
61pub struct TypedEnvelope<T> {
62 pub sender_id: ConnectionId,
63 pub original_sender_id: Option<PeerId>,
64 pub message_id: u32,
65 pub payload: T,
66}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
85pub struct Peer {
86 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
87 next_connection_id: AtomicU32,
88}
89
90#[derive(Clone)]
91pub struct ConnectionState {
92 outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
93 next_message_id: Arc<AtomicU32>,
94 response_channels:
95 Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
96}
97
/// Number of seconds a single outgoing message may take to be written.
const WRITE_TIMEOUT_SECS: u64 = 10;

/// Upper bound on the time spent writing one outgoing message.
const WRITE_TIMEOUT: Duration = Duration::from_secs(WRITE_TIMEOUT_SECS);
99
100impl Peer {
101 pub fn new() -> Arc<Self> {
102 Arc::new(Self {
103 connections: Default::default(),
104 next_connection_id: Default::default(),
105 })
106 }
107
108 pub async fn add_connection(
109 self: &Arc<Self>,
110 connection: Connection,
111 ) -> (
112 ConnectionId,
113 impl Future<Output = anyhow::Result<()>> + Send,
114 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115 ) {
116 // For outgoing messages, use an unbounded channel so that application code
117 // can always send messages without yielding. For incoming messages, use a
118 // bounded channel so that other peers will receive backpressure if they send
119 // messages faster than this peer can process them.
120 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
122
123 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
124 let connection_state = ConnectionState {
125 outgoing_tx,
126 next_message_id: Default::default(),
127 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
128 };
129 let mut writer = MessageStream::new(connection.tx);
130 let mut reader = MessageStream::new(connection.rx);
131
132 let this = self.clone();
133 let response_channels = connection_state.response_channels.clone();
134 let handle_io = async move {
135 let result = 'outer: loop {
136 let read_message = reader.read_message().fuse();
137 futures::pin_mut!(read_message);
138 loop {
139 futures::select_biased! {
140 outgoing = outgoing_rx.next().fuse() => match outgoing {
141 Some(outgoing) => {
142 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
143 None => break 'outer Err(anyhow!("timed out writing RPC message")),
144 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
145 _ => {}
146 }
147 }
148 None => break 'outer Ok(()),
149 },
150 incoming = read_message => match incoming {
151 Ok(incoming) => {
152 if incoming_tx.send(incoming).await.is_err() {
153 break 'outer Ok(());
154 }
155 break;
156 }
157 Err(error) => {
158 break 'outer Err(error).context("received invalid RPC message")
159 }
160 },
161 }
162 }
163 };
164
165 response_channels.lock().take();
166 this.connections.write().remove(&connection_id);
167 result
168 };
169
170 let response_channels = connection_state.response_channels.clone();
171 self.connections
172 .write()
173 .insert(connection_id, connection_state);
174
175 let incoming_rx = incoming_rx.filter_map(move |incoming| {
176 let response_channels = response_channels.clone();
177 async move {
178 if let Some(responding_to) = incoming.responding_to {
179 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
180 if let Some(mut tx) = channel {
181 let mut requester_resumed = barrier::channel();
182 if let Err(error) = tx.send((incoming, requester_resumed.0)).await {
183 log::debug!(
184 "received RPC but request future was dropped {:?}",
185 error.0 .0
186 );
187 }
188 // Drop response channel before awaiting on the barrier. This allows the
189 // barrier to get dropped even if the request's future is dropped before it
190 // has a chance to observe the response.
191 drop(tx);
192 requester_resumed.1.recv().await;
193 } else {
194 log::warn!("received RPC response to unknown request {}", responding_to);
195 }
196
197 None
198 } else {
199 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
200 Some(envelope)
201 } else {
202 log::error!("unable to construct a typed envelope");
203 None
204 }
205 }
206 }
207 });
208 (connection_id, handle_io, incoming_rx.boxed())
209 }
210
211 pub fn disconnect(&self, connection_id: ConnectionId) {
212 self.connections.write().remove(&connection_id);
213 }
214
215 pub fn reset(&self) {
216 self.connections.write().clear();
217 }
218
219 pub fn request<T: RequestMessage>(
220 &self,
221 receiver_id: ConnectionId,
222 request: T,
223 ) -> impl Future<Output = Result<T::Response>> {
224 self.request_internal(None, receiver_id, request)
225 }
226
227 pub fn forward_request<T: RequestMessage>(
228 &self,
229 sender_id: ConnectionId,
230 receiver_id: ConnectionId,
231 request: T,
232 ) -> impl Future<Output = Result<T::Response>> {
233 self.request_internal(Some(sender_id), receiver_id, request)
234 }
235
236 pub fn request_internal<T: RequestMessage>(
237 &self,
238 original_sender_id: Option<ConnectionId>,
239 receiver_id: ConnectionId,
240 request: T,
241 ) -> impl Future<Output = Result<T::Response>> {
242 let (tx, mut rx) = mpsc::channel(1);
243 let send = self.connection_state(receiver_id).and_then(|connection| {
244 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
245 connection
246 .response_channels
247 .lock()
248 .as_mut()
249 .ok_or_else(|| anyhow!("connection was closed"))?
250 .insert(message_id, tx);
251 connection
252 .outgoing_tx
253 .unbounded_send(request.into_envelope(
254 message_id,
255 None,
256 original_sender_id.map(|id| id.0),
257 ))
258 .map_err(|_| anyhow!("connection was closed"))?;
259 Ok(())
260 });
261 async move {
262 send?;
263 let (response, _barrier) = rx
264 .recv()
265 .await
266 .ok_or_else(|| anyhow!("connection was closed"))?;
267 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
268 Err(anyhow!("RPC request failed - {}", error.message))
269 } else {
270 T::Response::from_envelope(response)
271 .ok_or_else(|| anyhow!("received response of the wrong type"))
272 }
273 }
274 }
275
276 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
277 let connection = self.connection_state(receiver_id)?;
278 let message_id = connection
279 .next_message_id
280 .fetch_add(1, atomic::Ordering::SeqCst);
281 connection
282 .outgoing_tx
283 .unbounded_send(message.into_envelope(message_id, None, None))?;
284 Ok(())
285 }
286
287 pub fn forward_send<T: EnvelopedMessage>(
288 &self,
289 sender_id: ConnectionId,
290 receiver_id: ConnectionId,
291 message: T,
292 ) -> Result<()> {
293 let connection = self.connection_state(receiver_id)?;
294 let message_id = connection
295 .next_message_id
296 .fetch_add(1, atomic::Ordering::SeqCst);
297 connection
298 .outgoing_tx
299 .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
300 Ok(())
301 }
302
303 pub fn respond<T: RequestMessage>(
304 &self,
305 receipt: Receipt<T>,
306 response: T::Response,
307 ) -> Result<()> {
308 let connection = self.connection_state(receipt.sender_id)?;
309 let message_id = connection
310 .next_message_id
311 .fetch_add(1, atomic::Ordering::SeqCst);
312 connection
313 .outgoing_tx
314 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
315 Ok(())
316 }
317
318 pub fn respond_with_error<T: RequestMessage>(
319 &self,
320 receipt: Receipt<T>,
321 response: proto::Error,
322 ) -> Result<()> {
323 let connection = self.connection_state(receipt.sender_id)?;
324 let message_id = connection
325 .next_message_id
326 .fetch_add(1, atomic::Ordering::SeqCst);
327 connection
328 .outgoing_tx
329 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
330 Ok(())
331 }
332
333 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
334 let connections = self.connections.read();
335 let connection = connections
336 .get(&connection_id)
337 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
338 Ok(connection.clone())
339 }
340}
341
342#[cfg(test)]
343mod tests {
344 use super::*;
345 use crate::TypedEnvelope;
346 use async_tungstenite::tungstenite::Message as WebSocketMessage;
347 use gpui::TestAppContext;
348
349 #[gpui::test(iterations = 50)]
350 async fn test_request_response(cx: TestAppContext) {
351 let executor = cx.foreground();
352
353 // create 2 clients connected to 1 server
354 let server = Peer::new();
355 let client1 = Peer::new();
356 let client2 = Peer::new();
357
358 let (client1_to_server_conn, server_to_client_1_conn, _) =
359 Connection::in_memory(cx.background());
360 let (client1_conn_id, io_task1, client1_incoming) =
361 client1.add_connection(client1_to_server_conn).await;
362 let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
363
364 let (client2_to_server_conn, server_to_client_2_conn, _) =
365 Connection::in_memory(cx.background());
366 let (client2_conn_id, io_task3, client2_incoming) =
367 client2.add_connection(client2_to_server_conn).await;
368 let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
369
370 executor.spawn(io_task1).detach();
371 executor.spawn(io_task2).detach();
372 executor.spawn(io_task3).detach();
373 executor.spawn(io_task4).detach();
374 executor
375 .spawn(handle_messages(server_incoming1, server.clone()))
376 .detach();
377 executor
378 .spawn(handle_messages(client1_incoming, client1.clone()))
379 .detach();
380 executor
381 .spawn(handle_messages(server_incoming2, server.clone()))
382 .detach();
383 executor
384 .spawn(handle_messages(client2_incoming, client2.clone()))
385 .detach();
386
387 assert_eq!(
388 client1
389 .request(client1_conn_id, proto::Ping {},)
390 .await
391 .unwrap(),
392 proto::Ack {}
393 );
394
395 assert_eq!(
396 client2
397 .request(client2_conn_id, proto::Ping {},)
398 .await
399 .unwrap(),
400 proto::Ack {}
401 );
402
403 assert_eq!(
404 client1
405 .request(client1_conn_id, proto::Test { id: 1 },)
406 .await
407 .unwrap(),
408 proto::Test { id: 1 }
409 );
410
411 assert_eq!(
412 client2
413 .request(client2_conn_id, proto::Test { id: 2 })
414 .await
415 .unwrap(),
416 proto::Test { id: 2 }
417 );
418
419 client1.disconnect(client1_conn_id);
420 client2.disconnect(client1_conn_id);
421
422 async fn handle_messages(
423 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
424 peer: Arc<Peer>,
425 ) -> Result<()> {
426 while let Some(envelope) = messages.next().await {
427 let envelope = envelope.into_any();
428 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
429 let receipt = envelope.receipt();
430 peer.respond(receipt, proto::Ack {})?
431 } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
432 {
433 peer.respond(envelope.receipt(), envelope.payload.clone())?
434 } else {
435 panic!("unknown message type");
436 }
437 }
438
439 Ok(())
440 }
441 }
442
443 #[gpui::test(iterations = 50)]
444 async fn test_order_of_response_and_incoming(cx: TestAppContext) {
445 let executor = cx.foreground();
446 let server = Peer::new();
447 let client = Peer::new();
448
449 let (client_to_server_conn, server_to_client_conn, _) =
450 Connection::in_memory(cx.background());
451 let (client_to_server_conn_id, io_task1, mut client_incoming) =
452 client.add_connection(client_to_server_conn).await;
453 let (server_to_client_conn_id, io_task2, mut server_incoming) =
454 server.add_connection(server_to_client_conn).await;
455
456 executor.spawn(io_task1).detach();
457 executor.spawn(io_task2).detach();
458
459 executor
460 .spawn(async move {
461 let request = server_incoming
462 .next()
463 .await
464 .unwrap()
465 .into_any()
466 .downcast::<TypedEnvelope<proto::Ping>>()
467 .unwrap();
468
469 server
470 .send(
471 server_to_client_conn_id,
472 proto::Error {
473 message: "message 1".to_string(),
474 },
475 )
476 .unwrap();
477 server
478 .send(
479 server_to_client_conn_id,
480 proto::Error {
481 message: "message 2".to_string(),
482 },
483 )
484 .unwrap();
485 server.respond(request.receipt(), proto::Ack {}).unwrap();
486
487 // Prevent the connection from being dropped
488 server_incoming.next().await;
489 })
490 .detach();
491
492 let events = Arc::new(Mutex::new(Vec::new()));
493
494 let response = client.request(client_to_server_conn_id, proto::Ping {});
495 let response_task = executor.spawn({
496 let events = events.clone();
497 async move {
498 response.await.unwrap();
499 events.lock().push("response".to_string());
500 }
501 });
502
503 executor
504 .spawn({
505 let events = events.clone();
506 async move {
507 let incoming1 = client_incoming
508 .next()
509 .await
510 .unwrap()
511 .into_any()
512 .downcast::<TypedEnvelope<proto::Error>>()
513 .unwrap();
514 events.lock().push(incoming1.payload.message);
515 let incoming2 = client_incoming
516 .next()
517 .await
518 .unwrap()
519 .into_any()
520 .downcast::<TypedEnvelope<proto::Error>>()
521 .unwrap();
522 events.lock().push(incoming2.payload.message);
523
524 // Prevent the connection from being dropped
525 client_incoming.next().await;
526 }
527 })
528 .detach();
529
530 response_task.await;
531 assert_eq!(
532 &*events.lock(),
533 &[
534 "message 1".to_string(),
535 "message 2".to_string(),
536 "response".to_string()
537 ]
538 );
539 }
540
541 #[gpui::test(iterations = 50)]
542 async fn test_dropping_request_before_completion(cx: TestAppContext) {
543 let executor = cx.foreground();
544 let server = Peer::new();
545 let client = Peer::new();
546
547 let (client_to_server_conn, server_to_client_conn, _) =
548 Connection::in_memory(cx.background());
549 let (client_to_server_conn_id, io_task1, mut client_incoming) =
550 client.add_connection(client_to_server_conn).await;
551 let (server_to_client_conn_id, io_task2, mut server_incoming) =
552 server.add_connection(server_to_client_conn).await;
553
554 executor.spawn(io_task1).detach();
555 executor.spawn(io_task2).detach();
556
557 executor
558 .spawn(async move {
559 let request1 = server_incoming
560 .next()
561 .await
562 .unwrap()
563 .into_any()
564 .downcast::<TypedEnvelope<proto::Ping>>()
565 .unwrap();
566 let request2 = server_incoming
567 .next()
568 .await
569 .unwrap()
570 .into_any()
571 .downcast::<TypedEnvelope<proto::Ping>>()
572 .unwrap();
573
574 server
575 .send(
576 server_to_client_conn_id,
577 proto::Error {
578 message: "message 1".to_string(),
579 },
580 )
581 .unwrap();
582 server
583 .send(
584 server_to_client_conn_id,
585 proto::Error {
586 message: "message 2".to_string(),
587 },
588 )
589 .unwrap();
590 server.respond(request1.receipt(), proto::Ack {}).unwrap();
591 server.respond(request2.receipt(), proto::Ack {}).unwrap();
592
593 // Prevent the connection from being dropped
594 server_incoming.next().await;
595 })
596 .detach();
597
598 let events = Arc::new(Mutex::new(Vec::new()));
599
600 let request1 = client.request(client_to_server_conn_id, proto::Ping {});
601 let request1_task = executor.spawn(request1);
602 let request2 = client.request(client_to_server_conn_id, proto::Ping {});
603 let request2_task = executor.spawn({
604 let events = events.clone();
605 async move {
606 request2.await.unwrap();
607 events.lock().push("response 2".to_string());
608 }
609 });
610
611 executor
612 .spawn({
613 let events = events.clone();
614 async move {
615 let incoming1 = client_incoming
616 .next()
617 .await
618 .unwrap()
619 .into_any()
620 .downcast::<TypedEnvelope<proto::Error>>()
621 .unwrap();
622 events.lock().push(incoming1.payload.message);
623 let incoming2 = client_incoming
624 .next()
625 .await
626 .unwrap()
627 .into_any()
628 .downcast::<TypedEnvelope<proto::Error>>()
629 .unwrap();
630 events.lock().push(incoming2.payload.message);
631
632 // Prevent the connection from being dropped
633 client_incoming.next().await;
634 }
635 })
636 .detach();
637
638 // Allow the request to make some progress before dropping it.
639 cx.background().simulate_random_delay().await;
640 drop(request1_task);
641
642 request2_task.await;
643 assert_eq!(
644 &*events.lock(),
645 &[
646 "message 1".to_string(),
647 "message 2".to_string(),
648 "response 2".to_string()
649 ]
650 );
651 }
652
653 #[gpui::test(iterations = 50)]
654 async fn test_disconnect(cx: TestAppContext) {
655 let executor = cx.foreground();
656
657 let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
658
659 let client = Peer::new();
660 let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
661
662 let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
663 executor
664 .spawn(async move {
665 io_handler.await.ok();
666 io_ended_tx.send(()).await.unwrap();
667 })
668 .detach();
669
670 let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
671 executor
672 .spawn(async move {
673 incoming.next().await;
674 messages_ended_tx.send(()).await.unwrap();
675 })
676 .detach();
677
678 client.disconnect(connection_id);
679
680 io_ended_rx.recv().await;
681 messages_ended_rx.recv().await;
682 assert!(server_conn
683 .send(WebSocketMessage::Binary(vec![]))
684 .await
685 .is_err());
686 }
687
688 #[gpui::test(iterations = 50)]
689 async fn test_io_error(cx: TestAppContext) {
690 let executor = cx.foreground();
691 let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
692
693 let client = Peer::new();
694 let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
695 executor.spawn(io_handler).detach();
696 executor
697 .spawn(async move { incoming.next().await })
698 .detach();
699
700 let response = executor.spawn(client.request(connection_id, proto::Ping {}));
701 let _request = server_conn.rx.next().await.unwrap().unwrap();
702
703 drop(server_conn);
704 assert_eq!(
705 response.await.unwrap_err().to_string(),
706 "connection was closed"
707 );
708 }
709}