1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::{self, BoxFuture, LocalBoxFuture},
7 FutureExt, Stream, StreamExt,
8};
9use postage::{
10 broadcast, mpsc,
11 prelude::{Sink as _, Stream as _},
12};
13use std::{
14 any::{Any, TypeId},
15 collections::{HashMap, HashSet},
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
/// Identifies one connection registered with a [`Peer`].
///
/// Ids come from a per-`Peer` counter, so two different peers can hand out
/// equal ids.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
27
/// Id of the peer that originally sent a forwarded message.
///
/// Built from an envelope's `original_sender_id` field when the message is
/// dispatched to a handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
30
31type MessageHandler = Box<
32 dyn Send
33 + Sync
34 + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
35>;
36
37type ForegroundMessageHandler =
38 Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
39
40pub struct Receipt<T> {
41 pub sender_id: ConnectionId,
42 pub message_id: u32,
43 payload_type: PhantomData<T>,
44}
45
46pub struct TypedEnvelope<T> {
47 pub sender_id: ConnectionId,
48 pub original_sender_id: Option<PeerId>,
49 pub message_id: u32,
50 pub payload: T,
51}
52
53impl<T> TypedEnvelope<T> {
54 pub fn original_sender_id(&self) -> Result<PeerId> {
55 self.original_sender_id
56 .ok_or_else(|| anyhow!("missing original_sender_id"))
57 }
58}
59
60impl<T: RequestMessage> TypedEnvelope<T> {
61 pub fn receipt(&self) -> Receipt<T> {
62 Receipt {
63 sender_id: self.sender_id,
64 message_id: self.message_id,
65 payload_type: PhantomData,
66 }
67 }
68}
69
70pub type Router = RouterInternal<MessageHandler>;
71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
72pub struct RouterInternal<H> {
73 message_handlers: Vec<H>,
74 handler_types: HashSet<TypeId>,
75}
76
77pub struct Peer {
78 connections: RwLock<HashMap<ConnectionId, Connection>>,
79 next_connection_id: AtomicU32,
80 incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
81}
82
83#[derive(Clone)]
84struct Connection {
85 outgoing_tx: mpsc::Sender<proto::Envelope>,
86 next_message_id: Arc<AtomicU32>,
87 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
88}
89
90impl Peer {
91 pub fn new() -> Arc<Self> {
92 Arc::new(Self {
93 connections: Default::default(),
94 next_connection_id: Default::default(),
95 incoming_messages: broadcast::channel(256).0,
96 })
97 }
98
99 pub async fn add_connection<Conn, H, Fut>(
100 self: &Arc<Self>,
101 conn: Conn,
102 router: Arc<RouterInternal<H>>,
103 ) -> (
104 ConnectionId,
105 impl Future<Output = anyhow::Result<()>> + Send,
106 impl Future<Output = ()>,
107 )
108 where
109 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
110 Fut: Future<Output = ()>,
111 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113 + Send
114 + Unpin,
115 {
116 let (tx, rx) = conn.split();
117 let connection_id = ConnectionId(
118 self.next_connection_id
119 .fetch_add(1, atomic::Ordering::SeqCst),
120 );
121 let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
122 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123 let connection = Connection {
124 outgoing_tx,
125 next_message_id: Default::default(),
126 response_channels: Default::default(),
127 };
128 let mut writer = MessageStream::new(tx);
129 let mut reader = MessageStream::new(rx);
130
131 let handle_io = async move {
132 loop {
133 let read_message = reader.read_message().fuse();
134 futures::pin_mut!(read_message);
135 loop {
136 futures::select_biased! {
137 incoming = read_message => match incoming {
138 Ok(incoming) => {
139 if incoming_tx.send(incoming).await.is_err() {
140 return Ok(());
141 }
142 break;
143 }
144 Err(error) => {
145 Err(error).context("received invalid RPC message")?;
146 }
147 },
148 outgoing = outgoing_rx.recv().fuse() => match outgoing {
149 Some(outgoing) => {
150 if let Err(result) = writer.write_message(&outgoing).await {
151 Err(result).context("failed to write RPC message")?;
152 }
153 }
154 None => return Ok(()),
155 }
156 }
157 }
158 }
159 };
160
161 let mut broadcast_incoming_messages = self.incoming_messages.clone();
162 let response_channels = connection.response_channels.clone();
163 let handle_messages = async move {
164 while let Some(envelope) = incoming_rx.recv().await {
165 if let Some(responding_to) = envelope.responding_to {
166 let channel = response_channels.lock().await.remove(&responding_to);
167 if let Some(mut tx) = channel {
168 tx.send(envelope).await.ok();
169 } else {
170 log::warn!("received RPC response to unknown request {}", responding_to);
171 }
172 } else {
173 router.handle(connection_id, envelope.clone()).await;
174 if let Some(envelope) = proto::build_typed_envelope(connection_id, envelope) {
175 broadcast_incoming_messages.send(Arc::from(envelope)).await.ok();
176 } else {
177 log::error!("unable to construct a typed envelope");
178 }
179 }
180 }
181 response_channels.lock().await.clear();
182 };
183
184 self.connections
185 .write()
186 .await
187 .insert(connection_id, connection);
188
189 (connection_id, handle_io, handle_messages)
190 }
191
192 pub async fn disconnect(&self, connection_id: ConnectionId) {
193 self.connections.write().await.remove(&connection_id);
194 }
195
196 pub async fn reset(&self) {
197 self.connections.write().await.clear();
198 }
199
200 pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
201 self.incoming_messages
202 .subscribe()
203 .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
204 }
205
206 pub fn request<T: RequestMessage>(
207 self: &Arc<Self>,
208 receiver_id: ConnectionId,
209 request: T,
210 ) -> impl Future<Output = Result<T::Response>> {
211 self.request_internal(None, receiver_id, request)
212 }
213
214 pub fn forward_request<T: RequestMessage>(
215 self: &Arc<Self>,
216 sender_id: ConnectionId,
217 receiver_id: ConnectionId,
218 request: T,
219 ) -> impl Future<Output = Result<T::Response>> {
220 self.request_internal(Some(sender_id), receiver_id, request)
221 }
222
223 pub fn request_internal<T: RequestMessage>(
224 self: &Arc<Self>,
225 original_sender_id: Option<ConnectionId>,
226 receiver_id: ConnectionId,
227 request: T,
228 ) -> impl Future<Output = Result<T::Response>> {
229 let this = self.clone();
230 let (tx, mut rx) = mpsc::channel(1);
231 async move {
232 let mut connection = this.connection(receiver_id).await?;
233 let message_id = connection
234 .next_message_id
235 .fetch_add(1, atomic::Ordering::SeqCst);
236 connection
237 .response_channels
238 .lock()
239 .await
240 .insert(message_id, tx);
241 connection
242 .outgoing_tx
243 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
244 .await
245 .map_err(|_| anyhow!("connection was closed"))?;
246 let response = rx
247 .recv()
248 .await
249 .ok_or_else(|| anyhow!("connection was closed"))?;
250 T::Response::from_envelope(response)
251 .ok_or_else(|| anyhow!("received response of the wrong type"))
252 }
253 }
254
255 pub fn send<T: EnvelopedMessage>(
256 self: &Arc<Self>,
257 receiver_id: ConnectionId,
258 message: T,
259 ) -> impl Future<Output = Result<()>> {
260 let this = self.clone();
261 async move {
262 let mut connection = this.connection(receiver_id).await?;
263 let message_id = connection
264 .next_message_id
265 .fetch_add(1, atomic::Ordering::SeqCst);
266 connection
267 .outgoing_tx
268 .send(message.into_envelope(message_id, None, None))
269 .await?;
270 Ok(())
271 }
272 }
273
274 pub fn forward_send<T: EnvelopedMessage>(
275 self: &Arc<Self>,
276 sender_id: ConnectionId,
277 receiver_id: ConnectionId,
278 message: T,
279 ) -> impl Future<Output = Result<()>> {
280 let this = self.clone();
281 async move {
282 let mut connection = this.connection(receiver_id).await?;
283 let message_id = connection
284 .next_message_id
285 .fetch_add(1, atomic::Ordering::SeqCst);
286 connection
287 .outgoing_tx
288 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
289 .await?;
290 Ok(())
291 }
292 }
293
294 pub fn respond<T: RequestMessage>(
295 self: &Arc<Self>,
296 receipt: Receipt<T>,
297 response: T::Response,
298 ) -> impl Future<Output = Result<()>> {
299 let this = self.clone();
300 async move {
301 let mut connection = this.connection(receipt.sender_id).await?;
302 let message_id = connection
303 .next_message_id
304 .fetch_add(1, atomic::Ordering::SeqCst);
305 connection
306 .outgoing_tx
307 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
308 .await?;
309 Ok(())
310 }
311 }
312
313 fn connection(
314 self: &Arc<Self>,
315 connection_id: ConnectionId,
316 ) -> impl Future<Output = Result<Connection>> {
317 let this = self.clone();
318 async move {
319 let connections = this.connections.read().await;
320 let connection = connections
321 .get(&connection_id)
322 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
323 Ok(connection.clone())
324 }
325 }
326}
327
328impl<H, Fut> RouterInternal<H>
329where
330 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
331 Fut: Future<Output = ()>,
332{
333 pub fn new() -> Self {
334 Self {
335 message_handlers: Default::default(),
336 handler_types: Default::default(),
337 }
338 }
339
340 async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
341 let mut envelope = Some(message);
342 for handler in self.message_handlers.iter() {
343 if let Some(future) = handler(&mut envelope, connection_id) {
344 future.await;
345 return;
346 }
347 }
348 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
349 }
350}
351
352impl Router {
353 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
354 where
355 T: EnvelopedMessage,
356 Fut: 'static + Send + Future<Output = Result<()>>,
357 F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
358 {
359 if !self.handler_types.insert(TypeId::of::<T>()) {
360 panic!("duplicate handler type");
361 }
362
363 self.message_handlers
364 .push(Box::new(move |envelope, connection_id| {
365 if envelope.as_ref().map_or(false, T::matches_envelope) {
366 let envelope = Option::take(envelope).unwrap();
367 let message_id = envelope.id;
368 let future = handler(TypedEnvelope {
369 sender_id: connection_id,
370 original_sender_id: envelope.original_sender_id.map(PeerId),
371 message_id,
372 payload: T::from_envelope(envelope).unwrap(),
373 });
374 Some(
375 async move {
376 if let Err(error) = future.await {
377 log::error!(
378 "error handling message {} {}: {:?}",
379 T::NAME,
380 message_id,
381 error
382 );
383 }
384 }
385 .boxed(),
386 )
387 } else {
388 None
389 }
390 }));
391 }
392}
393
394impl ForegroundRouter {
395 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
396 where
397 T: EnvelopedMessage,
398 Fut: 'static + Future<Output = Result<()>>,
399 F: 'static + Fn(TypedEnvelope<T>) -> Fut,
400 {
401 if !self.handler_types.insert(TypeId::of::<T>()) {
402 panic!("duplicate handler type");
403 }
404
405 self.message_handlers
406 .push(Box::new(move |envelope, connection_id| {
407 if envelope.as_ref().map_or(false, T::matches_envelope) {
408 let envelope = Option::take(envelope).unwrap();
409 let message_id = envelope.id;
410 let future = handler(TypedEnvelope {
411 sender_id: connection_id,
412 original_sender_id: envelope.original_sender_id.map(PeerId),
413 message_id,
414 payload: T::from_envelope(envelope).unwrap(),
415 });
416 Some(
417 async move {
418 if let Err(error) = future.await {
419 log::error!(
420 "error handling message {} {}: {:?}",
421 T::NAME,
422 message_id,
423 error
424 );
425 }
426 }
427 .boxed_local(),
428 )
429 } else {
430 None
431 }
432 }));
433 }
434}
435
436impl<T> Clone for Receipt<T> {
437 fn clone(&self) -> Self {
438 Self {
439 sender_id: self.sender_id,
440 message_id: self.message_id,
441 payload_type: PhantomData,
442 }
443 }
444}
445
446impl<T> Copy for Receipt<T> {}
447
448impl fmt::Display for ConnectionId {
449 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450 self.0.fmt(f)
451 }
452}
453
454impl fmt::Display for PeerId {
455 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
456 self.0.fmt(f)
457 }
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each peer allocates ids independently, so use client2's own id
            // rather than relying on both peers having handed out the same one.
            client2.disconnect(client2_conn_id).await;
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}