1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Locally-unique identifier for a single connection registered with a
/// [`Peer`]. Ids are assigned sequentially by `Peer::add_connection` and are
/// only meaningful to the peer that issued them.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
27
28impl fmt::Display for ConnectionId {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 self.0.fmt(f)
31 }
32}
33
/// Identifier for a peer elsewhere in the RPC graph; carried in forwarded
/// envelopes as `original_sender_id` (see `Peer::forward_request` /
/// `Peer::forward_send`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
36
37impl fmt::Display for PeerId {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 self.0.fmt(f)
40 }
41}
42
/// Proof that a request of type `T` was received on some connection. Passing
/// a receipt to `Peer::respond` routes the reply back to the requester via
/// `sender_id` and tags it with `message_id` so the requester can correlate
/// it.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    // Zero-sized marker tying the receipt to the request's message type, so
    // `Peer::respond` only accepts the matching `T::Response`.
    payload_type: PhantomData<T>,
}
48
49impl<T> Clone for Receipt<T> {
50 fn clone(&self) -> Self {
51 Self {
52 sender_id: self.sender_id,
53 message_id: self.message_id,
54 payload_type: PhantomData,
55 }
56 }
57}
58
// `Receipt` is plain data (two integers plus a zero-sized marker), so it can
// be `Copy` regardless of whether `T` itself is.
impl<T> Copy for Receipt<T> {}
60
/// A decoded protobuf envelope paired with metadata about where it came from.
pub struct TypedEnvelope<T> {
    /// The connection this message arrived on.
    pub sender_id: ConnectionId,
    /// Id of the originating peer when the message was forwarded by an
    /// intermediary; `None` for messages sent directly.
    pub original_sender_id: Option<PeerId>,
    /// Per-connection sequence number assigned by the sender.
    pub message_id: u32,
    /// The decoded message payload.
    pub payload: T,
}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
/// A node in the RPC graph: owns a set of connections and provides
/// request/response and fire-and-forget messaging over them.
pub struct Peer {
    /// All live connections, keyed by id. Entries are removed by
    /// `disconnect`, `reset`, or when a connection's I/O future finishes.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of monotonically increasing connection ids.
    next_connection_id: AtomicU32,
}
89
/// Cloneable handle to the shared state of a single connection.
#[derive(Clone)]
pub struct ConnectionState {
    // Queue of envelopes to be written to the socket by the connection's I/O
    // future; unbounded so senders never block.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Source of per-connection message ids.
    next_message_id: Arc<AtomicU32>,
    // Channels awaiting responses, keyed by the originating request's message
    // id. Becomes `None` when the connection closes, so in-flight and new
    // requests fail fast instead of hanging.
    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}
96
// Maximum time to spend writing a single outgoing message before the whole
// connection is torn down with an error.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
98
impl Peer {
    /// Creates a peer with no connections. Connection ids start at zero.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

    /// Registers `connection` with this peer, returning:
    /// - the id assigned to the connection,
    /// - a future that performs all reads and writes for the connection — it
    ///   must be spawned/polled for any traffic to flow, and it resolves when
    ///   the connection closes (Ok) or fails (Err), and
    /// - a stream of incoming messages that are *not* responses to
    ///   outstanding requests, decoded into typed envelopes.
    pub async fn add_connection(
        self: &Arc<Self>,
        connection: Connection,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            // `Some(map)` while the connection is alive; the I/O future
            // `take()`s it on shutdown, dropping all pending senders.
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            // Outer loop: one iteration per message read from the socket. The
            // inner loop keeps servicing outgoing writes while a single read
            // is pending; `select_biased!` deliberately prioritizes the write
            // arm so queued outgoing messages are flushed before the next
            // incoming message is processed (the tests rely on this ordering).
            let result = 'outer: loop {
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                // Bound each write so a stuck socket tears the
                                // connection down rather than hanging forever.
                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
                                    _ => {}
                                }
                            }
                            // All senders dropped: connection was closed locally.
                            None => break 'outer Ok(()),
                        },
                        incoming = read_message => match incoming {
                            Ok(incoming) => {
                                // Bounded send: applies backpressure. An error
                                // means the incoming stream was dropped.
                                if incoming_tx.send(incoming).await.is_err() {
                                    break 'outer Ok(());
                                }
                                // Restart the outer loop to start a fresh read.
                                break;
                            }
                            Err(error) => {
                                break 'outer Err(error).context("received invalid RPC message")
                            }
                        },
                    }
                }
            };

            // Tear down: dropping the response channels fails all pending
            // requests, then the connection is forgotten.
            response_channels.lock().take();
            this.connections.write().remove(&connection_id);
            result
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Demultiplex incoming messages: responses are routed to the channel
        // registered by `request_internal` (and filtered out of the stream);
        // everything else is decoded into a typed envelope for the caller.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    // `as_mut()?` short-circuits to `None` if the connection
                    // has already been torn down.
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(mut tx) = channel {
                        if let Err(error) = tx.send(incoming).await {
                            // The requesting future was dropped before its
                            // response arrived; this is expected, not an error.
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else {
                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                        Some(envelope)
                    } else {
                        // Unknown/unhandled payload variant; drop the message.
                        log::error!("unable to construct a typed envelope");
                        None
                    }
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

    /// Drops the state for `connection_id`. The connection's I/O future will
    /// subsequently resolve once its outgoing senders are gone.
    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    /// Drops all connections at once.
    pub fn reset(&self) {
        self.connections.write().clear();
    }

    /// Sends `request` to `receiver_id`, resolving to the typed response (or
    /// an error if the connection closes or the peer replies with an error).
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

    /// Like [`Peer::request`], but stamps the envelope with `sender_id` as
    /// the original sender, for relaying a request on behalf of another peer.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

    /// Common implementation behind `request`/`forward_request`: registers a
    /// response channel keyed by the new message id, then queues the request.
    /// Note that sending happens eagerly (before the returned future is
    /// awaited); only awaiting the response is deferred.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, mut rx) = mpsc::channel(1);
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            // Register the response channel *before* sending, so the response
            // cannot arrive and find no one waiting.
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // `None` here means the response channel was dropped, i.e. the
            // connection closed before a response arrived.
            let response = rx
                .recv()
                .await
                .ok_or_else(|| anyhow!("connection was closed"))?;
            // An `Error` payload from the remote peer becomes a local `Err`.
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("request failed").context(error.message.clone()))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

    /// Queues a one-way message to `receiver_id`; no response is expected.
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, None))?;
        Ok(())
    }

    /// Like [`Peer::send`], but records `sender_id` as the original sender so
    /// the receiver can tell the message was relayed.
    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
        Ok(())
    }

    /// Answers the request identified by `receipt`, setting `responding_to`
    /// so the requester's `request` future resolves with this response.
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Answers the request identified by `receipt` with an error payload; the
    /// requester's `request` future will resolve to `Err`.
    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Looks up the (cheaply cloneable) state handle for a connection,
    /// failing if it has been disconnected or never existed.
    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}
334
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server; each issues `Ping` and `Test`
    /// requests and receives the matching responses.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Test { id: 1 },)
                .await
                .unwrap(),
            proto::Test { id: 1 }
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Test { id: 2 })
                .await
                .unwrap(),
            proto::Test { id: 2 }
        );

        client1.disconnect(client1_conn_id);
        // Fixed: this previously passed `client1_conn_id`, which only worked
        // by accident because each peer assigns ids starting at zero.
        client2.disconnect(client2_conn_id);

        /// Server/client-side message loop: replies to `Ping` with `Ack` and
        /// echoes `Test` payloads back verbatim.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
                {
                    peer.respond(envelope.receipt(), envelope.payload.clone())?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response must be delivered to the incoming
    /// stream before the request future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Queue two one-way messages *before* responding, to pin the
                // delivery order observed by the client below.
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not disturb delivery of other
    /// messages or a second concurrent request.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                // Respond to both requests even though the client will have
                // dropped the first one; the peer logs and discards it.
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting locally finishes the I/O future, ends the incoming
    /// stream, and makes the remote end's sends fail.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// Dropping the remote end of the connection fails an in-flight request
    /// with "connection was closed".
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}