1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 barrier, mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Identifies one connection held by a `Peer`. Ids are allocated from the
/// peer's own counter, so they are unique per `Peer`, not globally.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
27
28impl fmt::Display for ConnectionId {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 self.0.fmt(f)
31 }
32}
33
/// Identifies a remote peer, used when one peer relays a message on behalf
/// of another (see `TypedEnvelope::original_sender_id`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
36
37impl fmt::Display for PeerId {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 self.0.fmt(f)
40 }
41}
42
/// A handle identifying a received request, used to route a response back to
/// the requesting connection via `Peer::respond` / `Peer::respond_with_error`.
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    // Ties the receipt to the request's payload type without storing a value,
    // so `Peer::respond` only accepts the matching `T::Response`.
    payload_type: PhantomData<T>,
}
48
// `Clone` and `Copy` are implemented manually rather than derived: a derive
// would require `T: Clone` / `T: Copy`, but only `PhantomData<T>` is stored,
// so a receipt is copyable regardless of the payload type.
impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        Self {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

impl<T> Copy for Receipt<T> {}
60
/// A deserialized incoming message with a known payload type, plus the
/// routing metadata carried by its wire envelope.
pub struct TypedEnvelope<T> {
    /// The connection this message arrived on.
    pub sender_id: ConnectionId,
    /// For relayed messages, the peer that originally produced the message.
    pub original_sender_id: Option<PeerId>,
    /// The sender-assigned id of this message.
    pub message_id: u32,
    pub payload: T,
}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
/// A set of connections to other peers, multiplexing typed request/response
/// RPC messages over each connection.
pub struct Peer {
    /// Active connections, keyed by id.
    /// NOTE(review): this field is `pub`, so callers can mutate the map
    /// directly and bypass `disconnect`/`reset` — consider narrowing it.
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    // Source of monotonically increasing connection ids (per peer, not global).
    next_connection_id: AtomicU32,
}
89
/// Per-connection shared state. Cheap to clone: every field is a handle.
#[derive(Clone)]
pub struct ConnectionState {
    // Unbounded queue of envelopes consumed by the connection's IO task.
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    // Source of ids for messages sent on this connection.
    next_message_id: Arc<AtomicU32>,
    // Channels on which to deliver responses for in-flight requests, keyed by
    // the request's message id. The outer `Option` is taken (set to `None`)
    // when the connection closes, which makes pending requests fail.
    response_channels:
        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
97
98const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
99
impl Peer {
    /// Creates a new peer with no connections. Returned in an `Arc` because
    /// `add_connection` requires a clonable handle (`self: &Arc<Self>`).
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

    /// Registers `connection` with this peer. Returns:
    /// - the id assigned to the connection,
    /// - a future that drives the connection's reads and writes until it
    ///   closes or errors (the caller must spawn it), and
    /// - a stream of incoming typed messages. Responses to requests issued
    ///   via `request`/`forward_request` are routed to their waiting
    ///   requesters and filtered out of this stream.
    pub async fn add_connection(
        self: &Arc<Self>,
        connection: Connection,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        // IO task: interleaves writing queued outgoing messages with reading
        // incoming ones, until either direction closes or errors.
        let handle_io = async move {
            let result = 'outer: loop {
                // A fresh read future is created each time a complete message
                // has been received; the inner loop keeps polling it while
                // also servicing outgoing writes.
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    // `select_biased!` polls the write arm first, so queued
                    // outgoing messages are flushed ahead of new reads.
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                // Bound each write by WRITE_TIMEOUT so a stalled
                                // peer cannot wedge this task forever.
                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
                                    _ => {}
                                }
                            }
                            // All senders dropped (connection state removed from the
                            // peer): shut down cleanly.
                            None => break 'outer Ok(()),
                        },
                        incoming = read_message => match incoming {
                            Ok(incoming) => {
                                // Receiver dropped: treat as a clean close.
                                if incoming_tx.send(incoming).await.is_err() {
                                    break 'outer Ok(());
                                }
                                // Restart the outer loop to create a new read future.
                                break;
                            }
                            Err(error) => {
                                break 'outer Err(error).context("received invalid RPC message")
                            }
                        },
                    }
                }
            };

            // Teardown: taking the response-channel map drops every pending
            // sender, so in-flight requests resolve with "connection was
            // closed"; then unregister the connection from this peer.
            response_channels.lock().take();
            this.connections.write().remove(&connection_id);
            result
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        // Response router: messages that answer an outstanding request are
        // delivered to the requester and removed from the stream; everything
        // else is surfaced as a typed envelope.
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(mut tx) = channel {
                        let mut requester_resumed = barrier::channel();
                        tx.send((incoming, requester_resumed.0)).await.ok();
                        // Drop response channel before awaiting on the barrier. This allows the
                        // barrier to get dropped even if the request's future is dropped before it
                        // has a chance to observe the response.
                        drop(tx);
                        requester_resumed.1.recv().await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else {
                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                        Some(envelope)
                    } else {
                        log::error!("unable to construct a typed envelope");
                        None
                    }
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

    /// Drops the state for `connection_id`. Dropping the connection's
    /// `outgoing_tx` lets its IO task observe a closed queue and exit cleanly.
    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    /// Drops all connections at once (see `disconnect`).
    pub fn reset(&self) {
        self.connections.write().clear();
    }

    /// Sends `request` to `receiver_id`; the returned future resolves with
    /// the typed response, or an error if the connection closes or the remote
    /// peer answers with an `Error` payload.
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

    /// Like `request`, but stamps the envelope with `sender_id` as the
    /// original sender, for relaying a request on behalf of another peer.
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

    /// Shared implementation of `request` / `forward_request`. The response
    /// channel is registered and the request enqueued synchronously, before
    /// the returned future is first polled.
    /// NOTE(review): declared `pub` but looks like an internal helper. Also,
    /// if `unbounded_send` fails after the channel was inserted, the entry
    /// stays in `response_channels` until the whole connection closes.
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, mut rx) = mpsc::channel(1);
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            // `_barrier` signals the response router that this requester has
            // observed the response (it is dropped with this future).
            let (response, _barrier) = rx
                .recv()
                .await
                .ok_or_else(|| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("request failed").context(error.message.clone()))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

    /// Sends a one-way message to `receiver_id`; no response is expected.
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, None))?;
        Ok(())
    }

    /// Like `send`, but stamps the envelope with `sender_id` as the original
    /// sender, for relaying a message on behalf of another peer.
    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
        Ok(())
    }

    /// Answers the request identified by `receipt` with a typed response.
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Answers the request identified by `receipt` with an error payload
    /// instead of the typed response.
    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    /// Returns a clone of the state for `connection_id`, or an error if the
    /// connection is unknown (never added, or already closed/removed).
    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    /// Two clients connected to one server exchange typed requests and
    /// receive the matching typed responses.
    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {},)
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                })
            }
        );

        client1.disconnect(client1_conn_id);
        // Fix: disconnect client2's own connection. This previously passed
        // `client1_conn_id`, which only worked by coincidence because each
        // `Peer` allocates connection ids from its own counter starting at 0.
        client2.disconnect(client2_conn_id);

        /// Shared server/client message loop: answers `Ping` with `Ack` and
        /// `OpenBuffer` with a canned response keyed on the requested path.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response)?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    /// Messages sent before a response must be observed by the requester's
    /// side before the response future resolves (ordering guarantee provided
    /// by the barrier handshake in `add_connection`).
    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    /// Dropping one in-flight request must not disturb delivery ordering for
    /// other traffic or for a second concurrent request.
    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    /// Disconnecting tears down the IO task, ends the incoming stream, and
    /// makes the remote end's sends fail.
    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    /// An in-flight request fails with "connection was closed" when the
    /// underlying transport is dropped mid-request.
    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}