1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
2use super::Connection;
3use anyhow::{anyhow, Context, Result};
4use futures::stream::BoxStream;
5use futures::{FutureExt as _, StreamExt};
6use parking_lot::{Mutex, RwLock};
7use postage::{
8 mpsc,
9 prelude::{Sink as _, Stream as _},
10};
11use smol_timeout::TimeoutExt as _;
12use std::sync::atomic::Ordering::SeqCst;
13use std::{
14 collections::HashMap,
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 sync::{
19 atomic::{self, AtomicU32},
20 Arc,
21 },
22 time::Duration,
23};
24
/// Numeric id that a [`Peer`] assigns to each connection registered with it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);

impl fmt::Display for ConnectionId {
    /// Writes the raw number, forwarding the formatter so width/fill flags still apply.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, formatter)
    }
}
33
/// Numeric id of a peer; the type of [`TypedEnvelope::original_sender_id`].
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);

impl fmt::Display for PeerId {
    /// Writes the raw number, forwarding the formatter so width/fill flags still apply.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, formatter)
    }
}
42
43pub struct Receipt<T> {
44 pub sender_id: ConnectionId,
45 pub message_id: u32,
46 payload_type: PhantomData<T>,
47}
48
49impl<T> Clone for Receipt<T> {
50 fn clone(&self) -> Self {
51 Self {
52 sender_id: self.sender_id,
53 message_id: self.message_id,
54 payload_type: PhantomData,
55 }
56 }
57}
58
59impl<T> Copy for Receipt<T> {}
60
61pub struct TypedEnvelope<T> {
62 pub sender_id: ConnectionId,
63 pub original_sender_id: Option<PeerId>,
64 pub message_id: u32,
65 pub payload: T,
66}
67
68impl<T> TypedEnvelope<T> {
69 pub fn original_sender_id(&self) -> Result<PeerId> {
70 self.original_sender_id
71 .ok_or_else(|| anyhow!("missing original_sender_id"))
72 }
73}
74
75impl<T: RequestMessage> TypedEnvelope<T> {
76 pub fn receipt(&self) -> Receipt<T> {
77 Receipt {
78 sender_id: self.sender_id,
79 message_id: self.message_id,
80 payload_type: PhantomData,
81 }
82 }
83}
84
85pub struct Peer {
86 pub connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
87 next_connection_id: AtomicU32,
88}
89
90#[derive(Clone)]
91pub struct ConnectionState {
92 outgoing_tx: mpsc::Sender<proto::Envelope>,
93 next_message_id: Arc<AtomicU32>,
94 response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
95}
96
/// Longest a single outgoing message may take to write before the connection's I/O future
/// gives up with "timed out writing RPC message".
const WRITE_TIMEOUT: Duration = Duration::from_millis(10_000);
98
99impl Peer {
100 pub fn new() -> Arc<Self> {
101 Arc::new(Self {
102 connections: Default::default(),
103 next_connection_id: Default::default(),
104 })
105 }
106
107 pub async fn add_connection(
108 self: &Arc<Self>,
109 connection: Connection,
110 ) -> (
111 ConnectionId,
112 impl Future<Output = anyhow::Result<()>> + Send,
113 BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
114 ) {
115 let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
116 let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
117 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
118 let connection_state = ConnectionState {
119 outgoing_tx,
120 next_message_id: Default::default(),
121 response_channels: Arc::new(Mutex::new(Some(Default::default()))),
122 };
123 let mut writer = MessageStream::new(connection.tx);
124 let mut reader = MessageStream::new(connection.rx);
125
126 let this = self.clone();
127 let response_channels = connection_state.response_channels.clone();
128 let handle_io = async move {
129 let result = 'outer: loop {
130 let read_message = reader.read_message().fuse();
131 futures::pin_mut!(read_message);
132 loop {
133 futures::select_biased! {
134 incoming = read_message => match incoming {
135 Ok(incoming) => {
136 if incoming_tx.send(incoming).await.is_err() {
137 break 'outer Ok(());
138 }
139 break;
140 }
141 Err(error) => {
142 break 'outer Err(error).context("received invalid RPC message")
143 }
144 },
145 outgoing = outgoing_rx.recv().fuse() => match outgoing {
146 Some(outgoing) => {
147 match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
148 None => break 'outer Err(anyhow!("timed out writing RPC message")),
149 Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
150 _ => {}
151 }
152 }
153 None => break 'outer Ok(()),
154 }
155 }
156 }
157 };
158
159 response_channels.lock().take();
160 this.connections.write().remove(&connection_id);
161 result
162 };
163
164 let response_channels = connection_state.response_channels.clone();
165 self.connections
166 .write()
167 .insert(connection_id, connection_state);
168
169 let incoming_rx = incoming_rx.filter_map(move |incoming| {
170 let response_channels = response_channels.clone();
171 async move {
172 if let Some(responding_to) = incoming.responding_to {
173 let channel = response_channels.lock().as_mut()?.remove(&responding_to);
174 if let Some(mut tx) = channel {
175 tx.send(incoming).await.ok();
176 } else {
177 log::warn!("received RPC response to unknown request {}", responding_to);
178 }
179
180 None
181 } else {
182 if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
183 Some(envelope)
184 } else {
185 log::error!("unable to construct a typed envelope");
186 None
187 }
188 }
189 }
190 });
191 (connection_id, handle_io, incoming_rx.boxed())
192 }
193
194 pub fn disconnect(&self, connection_id: ConnectionId) {
195 self.connections.write().remove(&connection_id);
196 }
197
198 pub fn reset(&self) {
199 self.connections.write().clear();
200 }
201
202 pub fn request<T: RequestMessage>(
203 self: &Arc<Self>,
204 receiver_id: ConnectionId,
205 request: T,
206 ) -> impl Future<Output = Result<T::Response>> {
207 self.request_internal(None, receiver_id, request)
208 }
209
210 pub fn forward_request<T: RequestMessage>(
211 self: &Arc<Self>,
212 sender_id: ConnectionId,
213 receiver_id: ConnectionId,
214 request: T,
215 ) -> impl Future<Output = Result<T::Response>> {
216 self.request_internal(Some(sender_id), receiver_id, request)
217 }
218
219 pub fn request_internal<T: RequestMessage>(
220 self: &Arc<Self>,
221 original_sender_id: Option<ConnectionId>,
222 receiver_id: ConnectionId,
223 request: T,
224 ) -> impl Future<Output = Result<T::Response>> {
225 let this = self.clone();
226 let (tx, mut rx) = mpsc::channel(1);
227 async move {
228 let mut connection = this.connection_state(receiver_id)?;
229 let message_id = connection.next_message_id.fetch_add(1, SeqCst);
230 connection
231 .response_channels
232 .lock()
233 .as_mut()
234 .ok_or_else(|| anyhow!("connection was closed"))?
235 .insert(message_id, tx);
236 connection
237 .outgoing_tx
238 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
239 .await
240 .map_err(|_| anyhow!("connection was closed"))?;
241 let response = rx
242 .recv()
243 .await
244 .ok_or_else(|| anyhow!("connection was closed"))?;
245 if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
246 Err(anyhow!("request failed").context(error.message.clone()))
247 } else {
248 T::Response::from_envelope(response)
249 .ok_or_else(|| anyhow!("received response of the wrong type"))
250 }
251 }
252 }
253
254 pub fn send<T: EnvelopedMessage>(
255 self: &Arc<Self>,
256 receiver_id: ConnectionId,
257 message: T,
258 ) -> impl Future<Output = Result<()>> {
259 let this = self.clone();
260 async move {
261 let mut connection = this.connection_state(receiver_id)?;
262 let message_id = connection
263 .next_message_id
264 .fetch_add(1, atomic::Ordering::SeqCst);
265 connection
266 .outgoing_tx
267 .send(message.into_envelope(message_id, None, None))
268 .await?;
269 Ok(())
270 }
271 }
272
273 pub fn forward_send<T: EnvelopedMessage>(
274 self: &Arc<Self>,
275 sender_id: ConnectionId,
276 receiver_id: ConnectionId,
277 message: T,
278 ) -> impl Future<Output = Result<()>> {
279 let this = self.clone();
280 async move {
281 let mut connection = this.connection_state(receiver_id)?;
282 let message_id = connection
283 .next_message_id
284 .fetch_add(1, atomic::Ordering::SeqCst);
285 connection
286 .outgoing_tx
287 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
288 .await?;
289 Ok(())
290 }
291 }
292
293 pub fn respond<T: RequestMessage>(
294 self: &Arc<Self>,
295 receipt: Receipt<T>,
296 response: T::Response,
297 ) -> impl Future<Output = Result<()>> {
298 let this = self.clone();
299 async move {
300 let mut connection = this.connection_state(receipt.sender_id)?;
301 let message_id = connection
302 .next_message_id
303 .fetch_add(1, atomic::Ordering::SeqCst);
304 connection
305 .outgoing_tx
306 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
307 .await?;
308 Ok(())
309 }
310 }
311
312 pub fn respond_with_error<T: RequestMessage>(
313 self: &Arc<Self>,
314 receipt: Receipt<T>,
315 response: proto::Error,
316 ) -> impl Future<Output = Result<()>> {
317 let this = self.clone();
318 async move {
319 let mut connection = this.connection_state(receipt.sender_id)?;
320 let message_id = connection
321 .next_message_id
322 .fetch_add(1, atomic::Ordering::SeqCst);
323 connection
324 .outgoing_tx
325 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
326 .await?;
327 Ok(())
328 }
329 }
330
331 fn connection_state(&self, connection_id: ConnectionId) -> Result<ConnectionState> {
332 let connections = self.connections.read();
333 let connection = connections
334 .get(&connection_id)
335 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
336 Ok(connection.clone())
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use gpui::TestAppContext;

    #[gpui::test(iterations = 10)]
    async fn test_request_response(cx: TestAppContext) {
        let executor = cx.foreground();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();

        let (client1_to_server_conn, server_to_client_1_conn, _) =
            Connection::in_memory(cx.background());
        let (client1_conn_id, io_task1, client1_incoming) =
            client1.add_connection(client1_to_server_conn).await;
        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;

        let (client2_to_server_conn, server_to_client_2_conn, _) =
            Connection::in_memory(cx.background());
        let (client2_conn_id, io_task3, client2_incoming) =
            client2.add_connection(client2_to_server_conn).await;
        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();
        executor.spawn(io_task3).detach();
        executor.spawn(io_task4).detach();
        executor
            .spawn(handle_messages(server_incoming1, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client1_incoming, client1.clone()))
            .detach();
        executor
            .spawn(handle_messages(server_incoming2, server.clone()))
            .detach();
        executor
            .spawn(handle_messages(client2_incoming, client2.clone()))
            .detach();

        assert_eq!(
            client1
                .request(client1_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client2
                .request(client2_conn_id, proto::Ping {})
                .await
                .unwrap(),
            proto::Ack {}
        );

        assert_eq!(
            client1
                .request(
                    client1_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 1,
                        path: "path/one".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(0))
                }),
            }
        );

        assert_eq!(
            client2
                .request(
                    client2_conn_id,
                    proto::OpenBuffer {
                        project_id: 0,
                        worktree_id: 2,
                        path: "path/two".to_string(),
                    },
                )
                .await
                .unwrap(),
            proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    variant: Some(proto::buffer::Variant::Id(1))
                })
            }
        );

        // Each client tears down its own connection. Both ids happen to be 0 (each peer numbers
        // its connections independently), so passing the wrong one would pass by accident.
        client1.disconnect(client1_conn_id);
        client2.disconnect(client2_conn_id);

        /// Serves `Ping` and `OpenBuffer` requests arriving on `messages`, replying through
        /// `peer`; panics on any other message type.
        async fn handle_messages(
            mut messages: BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
            peer: Arc<Peer>,
        ) -> Result<()> {
            while let Some(envelope) = messages.next().await {
                let envelope = envelope.into_any();
                if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                    let receipt = envelope.receipt();
                    peer.respond(receipt, proto::Ack {}).await?
                } else if let Some(envelope) =
                    envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                {
                    let message = &envelope.payload;
                    let receipt = envelope.receipt();
                    let response = match message.path.as_str() {
                        "path/one" => {
                            assert_eq!(message.worktree_id, 1);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(0)),
                                }),
                            }
                        }
                        "path/two" => {
                            assert_eq!(message.worktree_id, 2);
                            proto::OpenBufferResponse {
                                buffer: Some(proto::Buffer {
                                    variant: Some(proto::buffer::Variant::Id(1)),
                                }),
                            }
                        }
                        _ => {
                            panic!("unexpected path {}", message.path);
                        }
                    };

                    peer.respond(receipt, response).await?
                } else {
                    panic!("unknown message type");
                }
            }

            Ok(())
        }
    }

    #[gpui::test(iterations = 10)]
    async fn test_order_of_response_and_incoming(cx: TestAppContext) {
        let executor = cx.foreground();
        let server = Peer::new();
        let client = Peer::new();

        let (client_to_server_conn, server_to_client_conn, _) =
            Connection::in_memory(cx.background());
        let (client_to_server_conn_id, io_task1, mut client_incoming) =
            client.add_connection(client_to_server_conn).await;
        let (server_to_client_conn_id, io_task2, mut server_incoming) =
            server.add_connection(server_to_client_conn).await;

        executor.spawn(io_task1).detach();
        executor.spawn(io_task2).detach();

        executor
            .spawn(async move {
                let request = server_incoming
                    .next()
                    .await
                    .unwrap()
                    .into_any()
                    .downcast::<TypedEnvelope<proto::Ping>>()
                    .unwrap();

                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 1".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .send(
                        server_to_client_conn_id,
                        proto::Error {
                            message: "message 2".to_string(),
                        },
                    )
                    .await
                    .unwrap();
                server
                    .respond(request.receipt(), proto::Ack {})
                    .await
                    .unwrap();

                // Prevent the connection from being dropped
                server_incoming.next().await;
            })
            .detach();

        let events = Arc::new(Mutex::new(Vec::new()));

        let response = client.request(client_to_server_conn_id, proto::Ping {});
        let response_task = executor.spawn({
            let events = events.clone();
            async move {
                response.await.unwrap();
                events.lock().push("response".to_string());
            }
        });

        executor
            .spawn({
                let events = events.clone();
                async move {
                    let incoming1 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming1.payload.message);
                    let incoming2 = client_incoming
                        .next()
                        .await
                        .unwrap()
                        .into_any()
                        .downcast::<TypedEnvelope<proto::Error>>()
                        .unwrap();
                    events.lock().push(incoming2.payload.message);

                    // Prevent the connection from being dropped
                    client_incoming.next().await;
                }
            })
            .detach();

        response_task.await;
        assert_eq!(
            &*events.lock(),
            &[
                "message 1".to_string(),
                "message 2".to_string(),
                "response".to_string()
            ]
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_disconnect(cx: TestAppContext) {
        let executor = cx.foreground();

        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;

        let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

        let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
        executor
            .spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

        client.disconnect(connection_id);

        io_ended_rx.recv().await;
        messages_ended_rx.recv().await;
        assert!(server_conn
            .send(WebSocketMessage::Binary(vec![]))
            .await
            .is_err());
    }

    #[gpui::test(iterations = 10)]
    async fn test_io_error(cx: TestAppContext) {
        let executor = cx.foreground();
        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());

        let client = Peer::new();
        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
        executor.spawn(io_handler).detach();
        executor
            .spawn(async move { incoming.next().await })
            .detach();

        let response = executor.spawn(client.request(connection_id, proto::Ping {}));
        let _request = server_conn.rx.next().await.unwrap().unwrap();

        drop(server_conn);
        assert_eq!(
            response.await.unwrap_err().to_string(),
            "connection was closed"
        );
    }
}