1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::{channel::oneshot, stream::BoxStream, FutureExt as _, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use postage::{
7 barrier, mpsc,
8 prelude::{Sink as _, Stream as _},
9};
10use smol_timeout::TimeoutExt as _;
11use std::sync::atomic::Ordering::SeqCst;
12use std::{
13 collections::HashMap,
14 fmt,
15 future::Future,
16 marker::PhantomData,
17 sync::{
18 atomic::{self, AtomicU32},
19 Arc,
20 },
21 time::Duration,
22};
23
/// Identifies one connection registered with a `Peer`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Formats the raw numeric id, forwarding any width/fill/alignment flags in `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
32
/// Identifies a peer as a raw numeric id.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Formats the raw numeric id, forwarding any width/fill/alignment flags in `f`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
41
42pub struct Receipt<T> {
43 pub sender_id: ConnectionId,
44 pub message_id: u32,
45 payload_type: PhantomData<T>,
46}
47
48impl<T> Clone for Receipt<T> {
49 fn clone(&self) -> Self {
50 Self {
51 sender_id: self.sender_id,
52 message_id: self.message_id,
53 payload_type: PhantomData,
54 }
55 }
56}
57
58impl<T> Copy for Receipt<T> {}
59
60pub struct TypedEnvelope<T> {
61 pub sender_id: ConnectionId,
62 pub original_sender_id: Option<PeerId>,
63 pub message_id: u32,
64 pub payload: T,
65}
66
67impl<T> TypedEnvelope<T> {
68 pub fn original_sender_id(&self) -> Result<PeerId> {
69 self.original_sender_id
70 .ok_or_else(|| anyhow!("missing original_sender_id"))
71 }
72}
73
74impl<T: RequestMessage> TypedEnvelope<T> {
75 pub fn receipt(&self) -> Receipt<T> {
76 Receipt {
77 sender_id: self.sender_id,
78 message_id: self.message_id,
79 payload_type: PhantomData,
80 }
81 }
82}
83
84pub struct Peer {
85 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
86 next_connection_id: AtomicU32,
87}
88
89#[derive(Clone)]
90pub struct ConnectionState {
91 outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
92 next_message_id: Arc<AtomicU32>,
93 response_channels:
94 Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
95}
96
/// How long the I/O loop waits for one outgoing message to be written before failing the
/// connection (ten seconds).
const WRITE_TIMEOUT: Duration = Duration::from_millis(10 * 1000);
98
99impl Peer {
100 pub fn new() -> Arc<Self> {
101 Arc::new(Self {
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub async fn add_connection(
108 self: &Arc<Self>,
109 connection: Connection,
110 ) -> (
111 ConnectionId,
112 impl Future<Output = anyhow::Result<()>> + Send,
113 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114 ) {
115 // For outgoing messages, use an unbounded channel so that application code
116 // can always send messages without yielding. For incoming messages, use a
117 // bounded channel so that other peers will receive backpressure if they send
118 // messages faster than this peer can process them.
119 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
120 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
121
122 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
123 let connection_state = ConnectionState {
124 outgoing_tx,
125 next_message_id: Default::default(),
126 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
127 };
128 let mut writer = MessageStream::new(connection.tx);
129 let mut reader = MessageStream::new(connection.rx);
130
131 let this = self.clone();
132 let response_channels = connection_state.response_channels.clone();
133 let handle_io = async move {
134 let result = 'outer: loop {
135 let read_message = reader.read_message().fuse();
136 futures::pin_mut!(read_message);
137 loop {
138 futures::select_biased! {
139 outgoing = outgoing_rx.next().fuse() => match outgoing {
140 Some(outgoing) => {
141 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
142 None => break 'outer Err(anyhow!("timed out writing RPC message")),
143 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
144 _ => {}
145 }
146 }
147 None => break 'outer Ok(()),
148 },
149 incoming = read_message => match incoming {
150 Ok(incoming) => {
151 if incoming_tx.send(incoming).await.is_err() {
152 break 'outer Ok(());
153 }
154 break;
155 }
156 Err(error) => {
157 break 'outer Err(error).context("received invalid RPC message")
158 }
159 },
160 }
161 }
162 };
163
164 response_channels.lock().take();
165 this.connections.write().remove(&connection_id);
166 result
167 };
168
169 let response_channels = connection_state.response_channels.clone();
170 self.connections
171 .write()
172 .insert(connection_id, connection_state);
173
174 let incoming_rx = incoming_rx.filter_map(move |incoming| {
175 let response_channels = response_channels.clone();
176 async move {
177 if let Some(responding_to) = incoming.responding_to {
178 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
179 if let Some(tx) = channel {
180 let mut requester_resumed = barrier::channel();
181 if let Err(error) = tx.send((incoming, requester_resumed.0)) {
182 log::debug!(
183 "received RPC but request future was dropped {:?}",
184 error.0
185 );
186 }
187 requester_resumed.1.recv().await;
188 } else {
189 log::warn!("received RPC response to unknown request {}", responding_to);
190 }
191
192 None
193 } else {
194 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
195 Some(envelope)
196 } else {
197 log::error!("unable to construct a typed envelope");
198 None
199 }
200 }
201 }
202 });
203 (connection_id, handle_io, incoming_rx.boxed())
204 }
205
206 pub fn disconnect(&self, connection_id: ConnectionId) {
207 self.connections.write().remove(&connection_id);
208 }
209
210 pub fn reset(&self) {
211 self.connections.write().clear();
212 }
213
214 pub fn request<T: RequestMessage>(
215 &self,
216 receiver_id: ConnectionId,
217 request: T,
218 ) -> impl Future<Output = Result<T::Response>> {
219 self.request_internal(None, receiver_id, request)
220 }
221
222 pub fn forward_request<T: RequestMessage>(
223 &self,
224 sender_id: ConnectionId,
225 receiver_id: ConnectionId,
226 request: T,
227 ) -> impl Future<Output = Result<T::Response>> {
228 self.request_internal(Some(sender_id), receiver_id, request)
229 }
230
231 pub fn request_internal<T: RequestMessage>(
232 &self,
233 original_sender_id: Option<ConnectionId>,
234 receiver_id: ConnectionId,
235 request: T,
236 ) -> impl Future<Output = Result<T::Response>> {
237 let (tx, rx) = oneshot::channel();
238 let send = self.connection_state(receiver_id).and_then(|connection| {
239 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
240 connection
241 .response_channels
242 .lock()
243 .as_mut()
244 .ok_or_else(|| anyhow!("connection was closed"))?
245 .insert(message_id, tx);
246 connection
247 .outgoing_tx
248 .unbounded_send(request.into_envelope(
249 message_id,
250 None,
251 original_sender_id.map(|id| id.0),
252 ))
253 .map_err(|_| anyhow!("connection was closed"))?;
254 Ok(())
255 });
256 async move {
257 send?;
258 let (response, _barrier) = rx.await.map_err(|_| anyhow!("connection was closed"))?;
259 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
260 Err(anyhow!("RPC request failed - {}", error.message))
261 } else {
262 T::Response::from_envelope(response)
263 .ok_or_else(|| anyhow!("received response of the wrong type"))
264 }
265 }
266 }
267
268 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
269 let connection = self.connection_state(receiver_id)?;
270 let message_id = connection
271 .next_message_id
272 .fetch_add(1, atomic::Ordering::SeqCst);
273 connection
274 .outgoing_tx
275 .unbounded_send(message.into_envelope(message_id, None, None))?;
276 Ok(())
277 }
278
279 pub fn forward_send<T: EnvelopedMessage>(
280 &self,
281 sender_id: ConnectionId,
282 receiver_id: ConnectionId,
283 message: T,
284 ) -> Result<()> {
285 let connection = self.connection_state(receiver_id)?;
286 let message_id = connection
287 .next_message_id
288 .fetch_add(1, atomic::Ordering::SeqCst);
289 connection
290 .outgoing_tx
291 .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
292 Ok(())
293 }
294
295 pub fn respond<T: RequestMessage>(
296 &self,
297 receipt: Receipt<T>,
298 response: T::Response,
299 ) -> Result<()> {
300 let connection = self.connection_state(receipt.sender_id)?;
301 let message_id = connection
302 .next_message_id
303 .fetch_add(1, atomic::Ordering::SeqCst);
304 connection
305 .outgoing_tx
306 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
307 Ok(())
308 }
309
310 pub fn respond_with_error<T: RequestMessage>(
311 &self,
312 receipt: Receipt<T>,
313 response: proto::Error,
314 ) -> Result<()> {
315 let connection = self.connection_state(receipt.sender_id)?;
316 let message_id = connection
317 .next_message_id
318 .fetch_add(1, atomic::Ordering::SeqCst);
319 connection
320 .outgoing_tx
321 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
322 Ok(())
323 }
324
325 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
326 let connections = self.connections.read();
327 let connection = connections
328 .get(&connection_id)
329 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
330 Ok(connection.clone())
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server each get the correct responses to their requests.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 })
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        // Each client disconnects using the id of its own connection.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Answers `Ping` with `Ack` and echoes `Test` back; panics on anything else.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response are observed before the requester sees the response.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request does not stall delivery of other messages or responses.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting ends the I/O future and the incoming stream, and closes the transport.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: &mut TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// A pending request fails with "connection was closed" when the transport goes away.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: &mut TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}