use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Connection;
use anyhow::{anyhow, Context, Result};
use futures::stream::BoxStream;
use futures::{FutureExt as _, StreamExt};
use parking_lot::{Mutex, RwLock};
use postage::{
    mpsc,
    prelude::{Sink as _, Stream as _},
};
use smol_timeout::TimeoutExt as _;
use std::sync::atomic::Ordering::SeqCst;
use std::{
    collections::HashMap,
    fmt,
    future::Future,
    marker::PhantomData,
    sync::{
        atomic::{self, AtomicU32},
        Arc,
    },
    time::Duration,
};

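/// Identifies a connection managed by a [`Peer`]. Ids are assigned from the
/// peer's own counter, so they are only meaningful to the peer that created them.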
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

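/// Identifies the original sender of a forwarded message, as carried in
/// [`TypedEnvelope::original_sender_id`].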
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

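/// Captures the connection and message id of an incoming request so that a
/// response can be sent later via [`Peer::respond`] or [`Peer::respond_with_error`].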
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    payload_type: PhantomData<T>,
}

impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        Self {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

impl<T> Copy for Receipt<T> {}

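/// An incoming message decoded into its concrete payload type, along with the
/// connection it arrived on and, for forwarded messages, the original sender.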
pub struct TypedEnvelope<T> {
    pub sender_id: ConnectionId,
    pub original_sender_id: Option<PeerId>,
    pub message_id: u32,
    pub payload: T,
}

impl<T> TypedEnvelope<T> {
    pub fn original_sender_id(&self) -> Result<PeerId> {
        self.original_sender_id
            .ok_or_else(|| anyhow!("missing original_sender_id"))
    }
}

impl<T: RequestMessage> TypedEnvelope<T> {
    pub fn receipt(&self) -> Receipt<T> {
        Receipt {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

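/// Manages a set of connections: assigns message ids, sends and forwards
/// envelopes, and routes responses back to the futures returned by
/// [`Peer::request`].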
pub struct Peer {
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    next_connection_id: AtomicU32,
}

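/// Per-connection state: the outgoing message channel, the next message id, and
/// the channels on which responses are delivered to pending requests. The
/// response-channel map is set to `None` once the connection's I/O future
/// finishes, causing pending and subsequent requests to fail.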
#[derive(Clone)]
pub struct ConnectionState {
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    next_message_id: Arc<AtomicU32>,
    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
}

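/// How long a single outgoing write may take before the I/O loop gives up and
/// shuts the connection down with an error.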
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);

impl Peer {
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

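    /// Registers a new connection, returning its id, a future that drives the
    /// connection's reads and writes (the caller is responsible for spawning or
    /// polling it), and a stream of incoming messages that are not responses to
    /// outstanding requests.
    ///
    /// A minimal usage sketch, assuming an executor and an established
    /// `Connection` (names here are placeholders; see the tests below for the
    /// real wiring):
    ///
    /// ```ignore
    /// let (connection_id, io_task, mut incoming) = peer.add_connection(connection).await;
    /// executor.spawn(io_task).detach();
    /// while let Some(envelope) = incoming.next().await {
    ///     // Downcast `envelope` to a concrete `TypedEnvelope<T>` and handle it.
    /// }
    /// ```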
    pub async fn add_connection(
        self: &Arc<Self>,
        connection: Connection,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
            let result = 'outer: loop {
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
                                    _ => {}
                                }
                            }
                            None => break 'outer Ok(()),
                        },
                        incoming = read_message => match incoming {
                            Ok(incoming) => {
                                if incoming_tx.send(incoming).await.is_err() {
                                    break 'outer Ok(());
                                }
                                break;
                            }
                            Err(error) => {
                                break 'outer Err(error).context("received invalid RPC message")
                            }
                        },
                    }
                }
            };

            response_channels.lock().take();
            this.connections.write().remove(&connection_id);
            result
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(mut tx) = channel {
                        if let Err(error) = tx.send(incoming).await {
                            log::debug!(
                                "received RPC but request future was dropped {:?}",
                                error.0
                            );
                        }
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else if let Some(envelope) =
                    proto::build_typed_envelope(connection_id, incoming)
                {
                    Some(envelope)
                } else {
                    log::error!("unable to construct a typed envelope");
                    None
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

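    /// Removes the connection from this peer, dropping its outgoing message
    /// channel so the connection's I/O future can finish and fail any pending
    /// requests.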
    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    pub fn reset(&self) {
        self.connections.write().clear();
    }

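    /// Sends a request over the given connection and resolves with the typed
    /// response, or an error if the connection closes or the remote side replies
    /// with a `proto::Error`.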
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

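    /// Like [`Peer::request`], but records `sender_id` as the original sender of
    /// the envelope, for requests forwarded on behalf of another connection.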
    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

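    /// Shared implementation of [`Peer::request`] and [`Peer::forward_request`]:
    /// registers a response channel under a fresh message id, sends the request
    /// envelope, and waits for the matching response.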
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, mut rx) = mpsc::channel(1);
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            let response = rx
                .recv()
                .await
                .ok_or_else(|| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("request failed").context(error.message.clone()))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

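    /// Sends a one-way message over the given connection without waiting for a
    /// response.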
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, None))?;
        Ok(())
    }

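    /// Like [`Peer::send`], but records `sender_id` as the original sender of the
    /// envelope.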
    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
        Ok(())
    }

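    /// Sends a response to a previously received request, addressed using the
    /// receipt captured from its [`TypedEnvelope`].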
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

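    /// Responds to a previously received request with a `proto::Error` instead of
    /// the request's usual response type.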
    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 50)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                })
            }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response)?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 50)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_dropping_request_before_completion(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request1 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();
                let request2 = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request1.receipt(), proto::Ack {}).unwrap();
                server.respond(request2.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let request1 = client.request(client_to_server_conn_id, proto::Ping {});
        let request1_task = executor.spawn(request1);
        let request2 = client.request(client_to_server_conn_id, proto::Ping {});
        let request2_task = executor.spawn({
            let events = events.clone();
            async move {
                request2.await.unwrap();
                events.lock().push("response 2".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        // Allow the request to make some progress before dropping it.
        cx.background().simulate_random_delay().await;
        drop(request1_task);

        request2_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response 2".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 50)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 50)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}