use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
use super::Connection;
use anyhow::{anyhow, Context, Result};
use futures::stream::BoxStream;
use futures::{FutureExt as _, StreamExt};
use parking_lot::{Mutex, RwLock};
use postage::{
    barrier, mpsc,
    prelude::{Sink as _, Stream as _},
};
use smol_timeout::TimeoutExt as _;
use std::sync::atomic::Ordering::SeqCst;
use std::{
    collections::HashMap,
    fmt,
    future::Future,
    marker::PhantomData,
    sync::{
        atomic::{self, AtomicU32},
        Arc,
    },
    time::Duration,
};

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.fmt(f)
    }
}

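/// A handle to a received request, used to send a response back over the
/// originating connection. Produced by `TypedEnvelope::receipt` and consumed by
/// `Peer::respond` / `Peer::respond_with_error`.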
pub struct Receipt<T> {
    pub sender_id: ConnectionId,
    pub message_id: u32,
    payload_type: PhantomData<T>,
}

impl<T> Clone for Receipt<T> {
    fn clone(&self) -> Self {
        Self {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

impl<T> Copy for Receipt<T> {}

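/// A decoded incoming message together with the connection it arrived on and,
/// when the message was forwarded by another peer, the id of the original sender.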
pub struct TypedEnvelope<T> {
    pub sender_id: ConnectionId,
    pub original_sender_id: Option<PeerId>,
    pub message_id: u32,
    pub payload: T,
}

impl<T> TypedEnvelope<T> {
    pub fn original_sender_id(&self) -> Result<PeerId> {
        self.original_sender_id
            .ok_or_else(|| anyhow!("missing original_sender_id"))
    }
}

impl<T: RequestMessage> TypedEnvelope<T> {
    pub fn receipt(&self) -> Receipt<T> {
        Receipt {
            sender_id: self.sender_id,
            message_id: self.message_id,
            payload_type: PhantomData,
        }
    }
}

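/// A bidirectional RPC endpoint. A `Peer` owns a set of connections and can send
/// fire-and-forget messages or make request/response calls over any of them.
///
/// A minimal usage sketch, mirroring the tests below (the `connection` and
/// `executor` values are assumed to come from the caller):
///
/// ```ignore
/// let peer = Peer::new();
/// let (connection_id, io_task, mut incoming) = peer.add_connection(connection).await;
/// executor.spawn(io_task).detach();
/// let ack = peer.request(connection_id, proto::Ping {}).await?;
/// ```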
pub struct Peer {
    pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
    next_connection_id: AtomicU32,
}

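/// Per-connection state shared between the `Peer` and its IO task.
/// `response_channels` maps the id of each in-flight request to the channel
/// awaiting its response; it is set to `None` when the connection closes so that
/// pending and future requests fail promptly.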
#[derive(Clone)]
pub struct ConnectionState {
    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
    next_message_id: Arc<AtomicU32>,
    response_channels:
        Arc<Mutex<Option<HashMap<u32, mpsc::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}

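/// How long a single outgoing message may take to write before the connection is
/// treated as failed.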
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);

impl Peer {
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            connections: Default::default(),
            next_connection_id: Default::default(),
        })
    }

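    /// Registers a new connection, returning its id, a future that must be polled
    /// to drive IO on the underlying socket, and a stream of incoming messages
    /// that are not responses to this peer's outstanding requests.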
    pub async fn add_connection(
        self: &Arc<Self>,
        connection: Connection,
    ) -> (
        ConnectionId,
        impl Future<Output = anyhow::Result<()>> + Send,
        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
    ) {
        // For outgoing messages, use an unbounded channel so that application code
        // can always send messages without yielding. For incoming messages, use a
        // bounded channel so that other peers will receive backpressure if they send
        // messages faster than this peer can process them.
        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
        let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();

        let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
        let connection_state = ConnectionState {
            outgoing_tx,
            next_message_id: Default::default(),
            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
        };
        let mut writer = MessageStream::new(connection.tx);
        let mut reader = MessageStream::new(connection.rx);

        let this = self.clone();
        let response_channels = connection_state.response_channels.clone();
        let handle_io = async move {
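            // Drive the connection's IO: writes are prioritized over reads via
            // `select_biased!`, each write is bounded by `WRITE_TIMEOUT`, and every
            // complete incoming message is forwarded to the bounded `incoming_tx`
            // channel before the next read begins.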
            let result = 'outer: loop {
                let read_message = reader.read_message().fuse();
                futures::pin_mut!(read_message);
                loop {
                    futures::select_biased! {
                        outgoing = outgoing_rx.next().fuse() => match outgoing {
                            Some(outgoing) => {
                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
                                    _ => {}
                                }
                            }
                            None => break 'outer Ok(()),
                        },
                        incoming = read_message => match incoming {
                            Ok(incoming) => {
                                if incoming_tx.send(incoming).await.is_err() {
                                    break 'outer Ok(());
                                }
                                break;
                            }
                            Err(error) => {
                                break 'outer Err(error).context("received invalid RPC message")
                            }
                        },
                    }
                }
            };

            response_channels.lock().take();
            this.connections.write().remove(&connection_id);
            result
        };

        let response_channels = connection_state.response_channels.clone();
        self.connections
            .write()
            .insert(connection_id, connection_state);

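        // Route incoming messages: a response is delivered to the channel that the
        // originating request registered for its message id, and the reader waits on
        // a barrier until that response has been handled so it is observed before any
        // later traffic; every other message is yielded to the caller as a typed
        // envelope.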
        let incoming_rx = incoming_rx.filter_map(move |incoming| {
            let response_channels = response_channels.clone();
            async move {
                if let Some(responding_to) = incoming.responding_to {
                    let channel = response_channels.lock().as_mut()?.remove(&responding_to);
                    if let Some(mut tx) = channel {
                        let mut requester_resumed = barrier::channel();
                        tx.send((incoming, requester_resumed.0)).await.ok();
                        requester_resumed.1.recv().await;
                    } else {
                        log::warn!("received RPC response to unknown request {}", responding_to);
                    }

                    None
                } else if let Some(envelope) =
                    proto::build_typed_envelope(connection_id, incoming)
                {
                    Some(envelope)
                } else {
                    log::error!("unable to construct a typed envelope");
                    None
                }
            }
        });
        (connection_id, handle_io, incoming_rx.boxed())
    }

    pub fn disconnect(&self, connection_id: ConnectionId) {
        self.connections.write().remove(&connection_id);
    }

    pub fn reset(&self) {
        self.connections.write().clear();
    }

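    /// Sends `request` over the given connection and resolves with the typed
    /// response, or with an error if the peer replies with `proto::Error` or the
    /// connection is closed first.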
    pub fn request<T: RequestMessage>(
        &self,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }

    pub fn forward_request<T: RequestMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }

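    /// Shared implementation of `request` and `forward_request`: registers a
    /// response channel under a freshly allocated message id, enqueues the request,
    /// and then awaits the matching response.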
    pub fn request_internal<T: RequestMessage>(
        &self,
        original_sender_id: Option<ConnectionId>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        let (tx, mut rx) = mpsc::channel(1);
        let send = self.connection_state(receiver_id).and_then(|connection| {
            let message_id = connection.next_message_id.fetch_add(1, SeqCst);
            connection
                .response_channels
                .lock()
                .as_mut()
                .ok_or_else(|| anyhow!("connection was closed"))?
                .insert(message_id, tx);
            connection
                .outgoing_tx
                .unbounded_send(request.into_envelope(
                    message_id,
                    None,
                    original_sender_id.map(|id| id.0),
                ))
                .map_err(|_| anyhow!("connection was closed"))?;
            Ok(())
        });
        async move {
            send?;
            let (response, _barrier) = rx
                .recv()
                .await
                .ok_or_else(|| anyhow!("connection was closed"))?;
            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
                Err(anyhow!("request failed").context(error.message.clone()))
            } else {
                T::Response::from_envelope(response)
                    .ok_or_else(|| anyhow!("received response of the wrong type"))
            }
        }
    }

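    /// Sends a fire-and-forget message over the given connection; no response is
    /// expected.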
    pub fn send<T: EnvelopedMessage>(&self, receiver_id: ConnectionId, message: T) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, None))?;
        Ok(())
    }

    pub fn forward_send<T: EnvelopedMessage>(
        &self,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        message: T,
    ) -> Result<()> {
        let connection = self.connection_state(receiver_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
        Ok(())
    }

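    /// Replies to a previously received request identified by `receipt`.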
    pub fn respond<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: T::Response,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

    pub fn respond_with_error<T: RequestMessage>(
        &self,
        receipt: Receipt<T>,
        response: proto::Error,
    ) -> Result<()> {
        let connection = self.connection_state(receipt.sender_id)?;
        let message_id = connection
            .next_message_id
            .fetch_add(1, atomic::Ordering::SeqCst);
        connection
            .outgoing_tx
            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
        Ok(())
    }

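    /// Looks up the state for the given connection, failing if it has been removed.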
    fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
        let connections = self.connections.read();
        let connection = connections
            .get(&connection_id)
            .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
        Ok(connection.clone())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 10)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                })
            }
        );

        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {})?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response)?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

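    // Verifies that a response to a request is observed only after all of the
    // messages the responder sent before that response have been yielded on the
    // incoming stream.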
    #[gpui::test(iterations = 10)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .unwrap();
                server.respond(request.receipt(), proto::Ack {}).unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

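    // Verifies that disconnecting tears down the IO task, ends the incoming
    // message stream, and causes writes from the remote side to fail.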
    #[gpui::test(iterations = 10)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

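    // Verifies that a pending request fails with "connection was closed" when the
    // underlying connection is dropped mid-request.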
    #[gpui::test(iterations = 10)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}