1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 barrier, mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Identifies a single connection registered with a `Peer`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Formats the id exactly like the wrapped `u32`, forwarding any
    /// width/alignment flags supplied by the caller.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
33
/// Identifies a peer in the RPC mesh (as opposed to a local connection).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Formats the id exactly like the wrapped `u32`, forwarding any
    /// width/alignment flags supplied by the caller.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
42
43pub struct Receipt<T> {
44 pub sender_id: ConnectionId,
45 pub message_id: u32,
46 payload_type: PhantomData<T>,
47}
48
49impl<T> Clone for Receipt<T> {
50 fn clone(&self) -> Self {
51 Self {
52 sender_id: self.sender_id,
53 message_id: self.message_id,
54 payload_type: PhantomData,
55 }
56 }
57}
58
59impl<T> Copy for Receipt<T> {}
60
/// A decoded message of type `T` together with metadata about where it came
/// from.
pub struct TypedEnvelope<T> {
    /// The connection this message arrived on.
    pub sender_id: ConnectionId,
    /// The originating peer when the message was forwarded on someone's
    /// behalf; `None` for messages sent directly.
    pub original_sender_id: Option<PeerId>,
    /// Message id assigned by the sending peer.
    pub message_id: u32,
    /// The decoded payload.
    pub payload: T,
}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
/// An RPC endpoint that can hold many simultaneous connections, issue
/// requests over them, and respond to requests received on them.
pub struct Peer {
    /// Per-connection state, keyed by the id assigned in `add_connection`.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    /// Monotonic counter used to mint fresh `ConnectionId`s.
    next_connection_id: AtomicU32,
}
89
/// Cheaply-clonable handle to a single connection's bookkeeping.
#[derive(Clone)]
pub struct ConnectionState {
    // Feeds envelopes to the connection's I/O loop; unbounded so senders
    // never have to yield.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Next message id to assign to an outgoing envelope on this connection.
    next_message_id: Arc<AtomicU32>,
    // One channel per in-flight request, keyed by the request's message id.
    // The outer `Option` is `Some` while the connection is alive and is
    // taken (set to `None`) when it shuts down, so pending and future
    // requests fail instead of hanging. The paired `barrier::Sender` lets
    // the response router wait until the requester has observed a response.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
97
/// Maximum time the I/O loop waits for a single outgoing message to be
/// written before failing the connection with a timeout error.
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
99
100impl Peer {
101 pub fn new() -> Arc<Self> {
102 Arc::new(Self {
103 connections: Default::default(),
104 next_connection_id: Default::default(),
105 })
106 }
107
108 pub async fn add_connection(
109 self: &Arc<Self>,
110 connection: Connection,
111 ) -> (
112 ConnectionId,
113 impl Future<Output = anyhow::Result<()>> + Send,
114 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
115 ) {
116 // For outgoing messages, use an unbounded channel so that application code
117 // can always send messages without yielding. For incoming messages, use a
118 // bounded channel so that other peers will receive backpressure if they send
119 // messages faster than this peer can process them.
120 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
121 let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
122
123 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
124 let connection_state = ConnectionState {
125 outgoing_tx,
126 next_message_id: Default::default(),
127 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
128 };
129 let mut writer = MessageStream::new(connection.tx);
130 let mut reader = MessageStream::new(connection.rx);
131
132 let this = self.clone();
133 let response_channels = connection_state.response_channels.clone();
134 let handle_io = async move {
135 let result = 'outer: loop {
136 let read_message = reader.read_message().fuse();
137 futures::pin_mut!(read_message);
138 loop {
139 futures::select_biased! {
140 outgoing = outgoing_rx.next().fuse() => match outgoing {
141 Some(outgoing) => {
142 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
143 None => break 'outer Err(anyhow!("timed out writing RPC message")),
144 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
145 _ => {}
146 }
147 }
148 None => break 'outer Ok(()),
149 },
150 incoming = read_message => match incoming {
151 Ok(incoming) => {
152 if incoming_tx.send(incoming).await.is_err() {
153 break 'outer Ok(());
154 }
155 break;
156 }
157 Err(error) => {
158 break 'outer Err(error).context("received invalid RPC message")
159 }
160 },
161 }
162 }
163 };
164
165 response_channels.lock().take();
166 this.connections.write().remove(&connection_id);
167 result
168 };
169
170 let response_channels = connection_state.response_channels.clone();
171 self.connections
172 .write()
173 .insert(connection_id, connection_state);
174
175 let incoming_rx = incoming_rx.filter_map(move |incoming| {
176 let response_channels = response_channels.clone();
177 async move {
178 if let Some(responding_to) = incoming.responding_to {
179 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
180 if let Some(mut tx) = channel {
181 let mut requester_resumed = barrier::channel();
182 if let Err(error) = tx.send((incoming, requester_resumed.0)).await {
183 log::debug!(
184 "received RPC but request future was dropped {:?}",
185 error.0 .0
186 );
187 }
188 // Drop response channel before awaiting on the barrier. This allows the
189 // barrier to get dropped even if the request's future is dropped before it
190 // has a chance to observe the response.
191 drop(tx);
192 requester_resumed.1.recv().await;
193 } else {
194 log::warn!("received RPC response to unknown request {}", responding_to);
195 }
196
197 None
198 } else {
199 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
200 Some(envelope)
201 } else {
202 log::error!("unable to construct a typed envelope");
203 None
204 }
205 }
206 }
207 });
208 (connection_id, handle_io, incoming_rx.boxed())
209 }
210
211 pub fn disconnect(&self, connection_id: ConnectionId) {
212 self.connections.write().remove(&connection_id);
213 }
214
215 pub fn reset(&self) {
216 self.connections.write().clear();
217 }
218
219 pub fn request<T: RequestMessage>(
220 &self,
221 receiver_id: ConnectionId,
222 request: T,
223 ) -> impl Future<Output = Result<T::Response>> {
224 self.request_internal(None, receiver_id, request)
225 }
226
227 pub fn forward_request<T: RequestMessage>(
228 &self,
229 sender_id: ConnectionId,
230 receiver_id: ConnectionId,
231 request: T,
232 ) -> impl Future<Output = Result<T::Response>> {
233 self.request_internal(Some(sender_id), receiver_id, request)
234 }
235
236 pub fn request_internal<T: RequestMessage>(
237 &self,
238 original_sender_id: Option<ConnectionId>,
239 receiver_id: ConnectionId,
240 request: T,
241 ) -> impl Future<Output = Result<T::Response>> {
242 let (tx, mut rx) = mpsc::channel(1);
243 let send = self.connection_state(receiver_id).and_then(|connection| {
244 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
245 connection
246 .response_channels
247 .lock()
248 .as_mut()
249 .ok_or_else(|| anyhow!("connection was closed"))?
250 .insert(message_id, tx);
251 connection
252 .outgoing_tx
253 .unbounded_send(request.into_envelope(
254 message_id,
255 None,
256 original_sender_id.map(|id| id.0),
257 ))
258 .map_err(|_| anyhow!("connection was closed"))?;
259 Ok(())
260 });
261 async move {
262 send?;
263 let (response, _barrier) = rx
264 .recv()
265 .await
266 .ok_or_else(|| anyhow!("connection was closed"))?;
267 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
268 Err(anyhow!("request failed").context(error.message.clone()))
269 } else {
270 T::Response::from_envelope(response)
271 .ok_or_else(|| anyhow!("received response of the wrong type"))
272 }
273 }
274 }
275
276 pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
277 let connection = self.connection_state(receiver_id)?;
278 let message_id = connection
279 .next_message_id
280 .fetch_add(1, atomic::Ordering::SeqCst);
281 connection
282 .outgoing_tx
283 .unbounded_send(message.into_envelope(message_id, None, None))?;
284 Ok(())
285 }
286
287 pub fn forward_send<T: EnvelopedMessage>(
288 &self,
289 sender_id: ConnectionId,
290 receiver_id: ConnectionId,
291 message: T,
292 ) -> Result<()> {
293 let connection = self.connection_state(receiver_id)?;
294 let message_id = connection
295 .next_message_id
296 .fetch_add(1, atomic::Ordering::SeqCst);
297 connection
298 .outgoing_tx
299 .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
300 Ok(())
301 }
302
303 pub fn respond<T: RequestMessage>(
304 &self,
305 receipt: Receipt<T>,
306 response: T::Response,
307 ) -> Result<()> {
308 let connection = self.connection_state(receipt.sender_id)?;
309 let message_id = connection
310 .next_message_id
311 .fetch_add(1, atomic::Ordering::SeqCst);
312 connection
313 .outgoing_tx
314 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
315 Ok(())
316 }
317
318 pub fn respond_with_error<T: RequestMessage>(
319 &self,
320 receipt: Receipt<T>,
321 response: proto::Error,
322 ) -> Result<()> {
323 let connection = self.connection_state(receipt.sender_id)?;
324 let message_id = connection
325 .next_message_id
326 .fetch_add(1, atomic::Ordering::SeqCst);
327 connection
328 .outgoing_tx
329 .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
330 Ok(())
331 }
332
333 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
334 let connections = self.connections.read();
335 let connection = connections
336 .get(&connection_id)
337 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
338 Ok(connection.clone())
339 }
340}
341
342#[cfg(test)]
343mod tests {
344 use super::*;
345 use crate::TypedEnvelope;
346 use async_tungstenite::tungstenite::Message as WebSocketMessage;
347 use gpui::TestAppContext;
348
349 #[gpui::test(iterations = 50)]
350 async fn test_request_response(cx: TestAppContext) {
351 let executor = cx.foreground();
352
353 // create 2 clients connected to 1 server
354 let server = Peer::new();
355 let client1 = Peer::new();
356 let client2 = Peer::new();
357
358 let (client1_to_server_conn, server_to_client_1_conn, _) =
359 Connection::in_memory(cx.background());
360 let (client1_conn_id, io_task1, client1_incoming) =
361 client1.add_connection(client1_to_server_conn).await;
362 let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
363
364 let (client2_to_server_conn, server_to_client_2_conn, _) =
365 Connection::in_memory(cx.background());
366 let (client2_conn_id, io_task3, client2_incoming) =
367 client2.add_connection(client2_to_server_conn).await;
368 let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
369
370 executor.spawn(io_task1).detach();
371 executor.spawn(io_task2).detach();
372 executor.spawn(io_task3).detach();
373 executor.spawn(io_task4).detach();
374 executor
375 .spawn(handle_messages(server_incoming1, server.clone()))
376 .detach();
377 executor
378 .spawn(handle_messages(client1_incoming, client1.clone()))
379 .detach();
380 executor
381 .spawn(handle_messages(server_incoming2, server.clone()))
382 .detach();
383 executor
384 .spawn(handle_messages(client2_incoming, client2.clone()))
385 .detach();
386
387 assert_eq!(
388 client1
389 .request(client1_conn_id, proto::Ping {},)
390 .await
391 .unwrap(),
392 proto::Ack {}
393 );
394
395 assert_eq!(
396 client2
397 .request(client2_conn_id, proto::Ping {},)
398 .await
399 .unwrap(),
400 proto::Ack {}
401 );
402
403 assert_eq!(
404 client1
405 .request(
406 client1_conn_id,
407 proto::OpenBuffer {
408 project_id: 0,
409 worktree_id: 1,
410 path: "path/one".to_string(),
411 },
412 )
413 .await
414 .unwrap(),
415 proto::OpenBufferResponse {
416 buffer: Some(proto::Buffer {
417 variant: Some(proto::buffer::Variant::Id(0))
418 }),
419 }
420 );
421
422 assert_eq!(
423 client2
424 .request(
425 client2_conn_id,
426 proto::OpenBuffer {
427 project_id: 0,
428 worktree_id: 2,
429 path: "path/two".to_string(),
430 },
431 )
432 .await
433 .unwrap(),
434 proto::OpenBufferResponse {
435 buffer: Some(proto::Buffer {
436 variant: Some(proto::buffer::Variant::Id(1))
437 })
438 }
439 );
440
441 client1.disconnect(client1_conn_id);
442 client2.disconnect(client1_conn_id);
443
444 async fn handle_messages(
445 mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
446 peer: Arc<Peer>,
447 ) -> Result<()> {
448 while let Some(envelope) = messages.next().await {
449 let envelope = envelope.into_any();
450 if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
451 let receipt = envelope.receipt();
452 peer.respond(receipt, proto::Ack {})?
453 } else if let Some(envelope) =
454 envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
455 {
456 let message = &envelope.payload;
457 let receipt = envelope.receipt();
458 let response = match message.path.as_str() {
459 "path/one" => {
460 assert_eq!(message.worktree_id, 1);
461 proto::OpenBufferResponse {
462 buffer: Some(proto::Buffer {
463 variant: Some(proto::buffer::Variant::Id(0)),
464 }),
465 }
466 }
467 "path/two" => {
468 assert_eq!(message.worktree_id, 2);
469 proto::OpenBufferResponse {
470 buffer: Some(proto::Buffer {
471 variant: Some(proto::buffer::Variant::Id(1)),
472 }),
473 }
474 }
475 _ => {
476 panic!("unexpected path {}", message.path);
477 }
478 };
479
480 peer.respond(receipt, response)?
481 } else {
482 panic!("unknown message type");
483 }
484 }
485
486 Ok(())
487 }
488 }
489
    /// Messages the server sends *before* responding to a request must be
    /// observed by the client before the request's future resolves.
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Send two plain messages first, then the response.
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        // Both interleaved messages must precede the response.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }
587
    /// Dropping one request's future mid-flight must not disturb delivery of
    /// interleaved messages or of a second request's response.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                // Interleave two plain messages before answering both requests.
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        // Records the order in which the client observes events.
        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        // The second request still resolves after both interleaved messages,
        // regardless of where request1 was abandoned.
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }
699
700 #[gpui::test(iterations = 50)]
701 async fn test_disconnect(cx: TestAppContext) {
702 let executor = cx.foreground();
703
704 let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
705
706 let client = Peer::new();
707 let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
708
709 let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
710 executor
711 .spawn(async move {
712 io_handler.await.ok();
713 io_ended_tx.send(()).await.unwrap();
714 })
715 .detach();
716
717 let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
718 executor
719 .spawn(async move {
720 incoming.next().await;
721 messages_ended_tx.send(()).await.unwrap();
722 })
723 .detach();
724
725 client.disconnect(connection_id);
726
727 io_ended_rx.recv().await;
728 messages_ended_rx.recv().await;
729 assert!(server_conn
730 .send(WebSocketMessage::Binary(vec![]))
731 .await
732 .is_err());
733 }
734
735 #[gpui::test(iterations = 50)]
736 async fn test_io_error(cx: TestAppContext) {
737 let executor = cx.foreground();
738 let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
739
740 let client = Peer::new();
741 let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
742 executor.spawn(io_handler).detach();
743 executor
744 .spawn(async move { incoming.next().await })
745 .detach();
746
747 let response = executor.spawn(client.request(connection_id, proto::Ping {}));
748 let _request = server_conn.rx.next().await.unwrap().unwrap();
749
750 drop(server_conn);
751 assert_eq!(
752 response.await.unwrap_err().to_string(),
753 "connection was closed"
754 );
755 }
756}