1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::{self, BoxFuture, LocalBoxFuture},
7 FutureExt, Stream, StreamExt,
8};
9use postage::{
10 broadcast, mpsc,
11 prelude::{Sink as _, Stream as _},
12};
13use std::{
14 any::{Any, TypeId},
15 collections::{HashMap, HashSet},
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
/// Identifier a [`Peer`] assigns to each connection registered with it.
///
/// Values come from a counter owned by the issuing `Peer` (starting at 0), so an
/// id is only meaningful relative to that `Peer`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ConnectionId(pub u32);

/// Identifier of the peer that originally sent a forwarded message, as carried in
/// an envelope's `original_sender_id` field.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct PeerId(pub u32);
30
31type MessageHandler = Box<
32 dyn Send
33 + Sync
34 + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
35>;
36
37type ForegroundMessageHandler =
38 Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
39
40pub struct Receipt<T> {
41 sender_id: ConnectionId,
42 message_id: u32,
43 payload_type: PhantomData<T>,
44}
45
46pub struct TypedEnvelope<T> {
47 pub sender_id: ConnectionId,
48 pub original_sender_id: Option<PeerId>,
49 pub message_id: u32,
50 pub payload: T,
51}
52
53impl<T> TypedEnvelope<T> {
54 pub fn original_sender_id(&self) -> Result<PeerId> {
55 self.original_sender_id
56 .ok_or_else(|| anyhow!("missing original_sender_id"))
57 }
58}
59
60impl<T: RequestMessage> TypedEnvelope<T> {
61 pub fn receipt(&self) -> Receipt<T> {
62 Receipt {
63 sender_id: self.sender_id,
64 message_id: self.message_id,
65 payload_type: PhantomData,
66 }
67 }
68}
69
70pub type Router = RouterInternal<MessageHandler>;
71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
72pub struct RouterInternal<H> {
73 message_handlers: Vec<H>,
74 handler_types: HashSet<TypeId>,
75}
76
77pub struct Peer {
78 connections: RwLock<HashMap<ConnectionId, Connection>>,
79 next_connection_id: AtomicU32,
80 incoming_messages: broadcast::Sender<Arc<dyn Any + Send + Sync>>,
81}
82
83#[derive(Clone)]
84struct Connection {
85 outgoing_tx: mpsc::Sender<proto::Envelope>,
86 next_message_id: Arc<AtomicU32>,
87 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
88}
89
90impl Peer {
91 pub fn new() -> Arc<Self> {
92 Arc::new(Self {
93 connections: Default::default(),
94 next_connection_id: Default::default(),
95 incoming_messages: broadcast::channel(256).0,
96 })
97 }
98
99 pub async fn add_connection<Conn, H, Fut>(
100 self: &Arc<Self>,
101 conn: Conn,
102 router: Arc<RouterInternal<H>>,
103 ) -> (
104 ConnectionId,
105 impl Future<Output = anyhow::Result<()>> + Send,
106 impl Future<Output = ()>,
107 )
108 where
109 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
110 Fut: Future<Output = ()>,
111 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
112 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
113 + Send
114 + Unpin,
115 {
116 let (tx, rx) = conn.split();
117 let connection_id = ConnectionId(
118 self.next_connection_id
119 .fetch_add(1, atomic::Ordering::SeqCst),
120 );
121 let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
122 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
123 let connection = Connection {
124 outgoing_tx,
125 next_message_id: Default::default(),
126 response_channels: Default::default(),
127 };
128 let mut writer = MessageStream::new(tx);
129 let mut reader = MessageStream::new(rx);
130
131 let handle_io = async move {
132 loop {
133 let read_message = reader.read_message().fuse();
134 futures::pin_mut!(read_message);
135 loop {
136 futures::select_biased! {
137 incoming = read_message => match incoming {
138 Ok(incoming) => {
139 if incoming_tx.send(incoming).await.is_err() {
140 return Ok(());
141 }
142 break;
143 }
144 Err(error) => {
145 Err(error).context("received invalid RPC message")?;
146 }
147 },
148 outgoing = outgoing_rx.recv().fuse() => match outgoing {
149 Some(outgoing) => {
150 if let Err(result) = writer.write_message(&outgoing).await {
151 Err(result).context("failed to write RPC message")?;
152 }
153 }
154 None => return Ok(()),
155 }
156 }
157 }
158 }
159 };
160
161 let mut broadcast_incoming_messages = self.incoming_messages.clone();
162 let response_channels = connection.response_channels.clone();
163 let handle_messages = async move {
164 while let Some(envelope) = incoming_rx.recv().await {
165 if let Some(responding_to) = envelope.responding_to {
166 let channel = response_channels.lock().await.remove(&responding_to);
167 if let Some(mut tx) = channel {
168 tx.send(envelope).await.ok();
169 } else {
170 log::warn!("received RPC response to unknown request {}", responding_to);
171 }
172 } else {
173 router.handle(connection_id, envelope.clone()).await;
174 match proto::build_typed_envelope(connection_id, envelope) {
175 Ok(envelope) => {
176 broadcast_incoming_messages.send(envelope).await.ok();
177 }
178 Err(error) => log::error!("{}", error),
179 }
180 }
181 }
182 response_channels.lock().await.clear();
183 };
184
185 self.connections
186 .write()
187 .await
188 .insert(connection_id, connection);
189
190 (connection_id, handle_io, handle_messages)
191 }
192
193 pub async fn disconnect(&self, connection_id: ConnectionId) {
194 self.connections.write().await.remove(&connection_id);
195 }
196
197 pub async fn reset(&self) {
198 self.connections.write().await.clear();
199 }
200
201 pub fn subscribe<T: EnvelopedMessage>(&self) -> impl Stream<Item = Arc<TypedEnvelope<T>>> {
202 self.incoming_messages
203 .subscribe()
204 .filter_map(|envelope| future::ready(Arc::downcast(envelope).ok()))
205 }
206
207 pub fn request<T: RequestMessage>(
208 self: &Arc<Self>,
209 receiver_id: ConnectionId,
210 request: T,
211 ) -> impl Future<Output = Result<T::Response>> {
212 self.request_internal(None, receiver_id, request)
213 }
214
215 pub fn forward_request<T: RequestMessage>(
216 self: &Arc<Self>,
217 sender_id: ConnectionId,
218 receiver_id: ConnectionId,
219 request: T,
220 ) -> impl Future<Output = Result<T::Response>> {
221 self.request_internal(Some(sender_id), receiver_id, request)
222 }
223
224 pub fn request_internal<T: RequestMessage>(
225 self: &Arc<Self>,
226 original_sender_id: Option<ConnectionId>,
227 receiver_id: ConnectionId,
228 request: T,
229 ) -> impl Future<Output = Result<T::Response>> {
230 let this = self.clone();
231 let (tx, mut rx) = mpsc::channel(1);
232 async move {
233 let mut connection = this.connection(receiver_id).await?;
234 let message_id = connection
235 .next_message_id
236 .fetch_add(1, atomic::Ordering::SeqCst);
237 connection
238 .response_channels
239 .lock()
240 .await
241 .insert(message_id, tx);
242 connection
243 .outgoing_tx
244 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
245 .await
246 .map_err(|_| anyhow!("connection was closed"))?;
247 let response = rx
248 .recv()
249 .await
250 .ok_or_else(|| anyhow!("connection was closed"))?;
251 T::Response::from_envelope(response)
252 .ok_or_else(|| anyhow!("received response of the wrong type"))
253 }
254 }
255
256 pub fn send<T: EnvelopedMessage>(
257 self: &Arc<Self>,
258 receiver_id: ConnectionId,
259 message: T,
260 ) -> impl Future<Output = Result<()>> {
261 let this = self.clone();
262 async move {
263 let mut connection = this.connection(receiver_id).await?;
264 let message_id = connection
265 .next_message_id
266 .fetch_add(1, atomic::Ordering::SeqCst);
267 connection
268 .outgoing_tx
269 .send(message.into_envelope(message_id, None, None))
270 .await?;
271 Ok(())
272 }
273 }
274
275 pub fn forward_send<T: EnvelopedMessage>(
276 self: &Arc<Self>,
277 sender_id: ConnectionId,
278 receiver_id: ConnectionId,
279 message: T,
280 ) -> impl Future<Output = Result<()>> {
281 let this = self.clone();
282 async move {
283 let mut connection = this.connection(receiver_id).await?;
284 let message_id = connection
285 .next_message_id
286 .fetch_add(1, atomic::Ordering::SeqCst);
287 connection
288 .outgoing_tx
289 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
290 .await?;
291 Ok(())
292 }
293 }
294
295 pub fn respond<T: RequestMessage>(
296 self: &Arc<Self>,
297 receipt: Receipt<T>,
298 response: T::Response,
299 ) -> impl Future<Output = Result<()>> {
300 let this = self.clone();
301 async move {
302 let mut connection = this.connection(receipt.sender_id).await?;
303 let message_id = connection
304 .next_message_id
305 .fetch_add(1, atomic::Ordering::SeqCst);
306 connection
307 .outgoing_tx
308 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
309 .await?;
310 Ok(())
311 }
312 }
313
314 fn connection(
315 self: &Arc<Self>,
316 connection_id: ConnectionId,
317 ) -> impl Future<Output = Result<Connection>> {
318 let this = self.clone();
319 async move {
320 let connections = this.connections.read().await;
321 let connection = connections
322 .get(&connection_id)
323 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
324 Ok(connection.clone())
325 }
326 }
327}
328
329impl<H, Fut> RouterInternal<H>
330where
331 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
332 Fut: Future<Output = ()>,
333{
334 pub fn new() -> Self {
335 Self {
336 message_handlers: Default::default(),
337 handler_types: Default::default(),
338 }
339 }
340
341 async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
342 let mut envelope = Some(message);
343 for handler in self.message_handlers.iter() {
344 if let Some(future) = handler(&mut envelope, connection_id) {
345 future.await;
346 return;
347 }
348 }
349 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
350 }
351}
352
353impl Router {
354 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
355 where
356 T: EnvelopedMessage,
357 Fut: 'static + Send + Future<Output = Result<()>>,
358 F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
359 {
360 if !self.handler_types.insert(TypeId::of::<T>()) {
361 panic!("duplicate handler type");
362 }
363
364 self.message_handlers
365 .push(Box::new(move |envelope, connection_id| {
366 if envelope.as_ref().map_or(false, T::matches_envelope) {
367 let envelope = Option::take(envelope).unwrap();
368 let message_id = envelope.id;
369 let future = handler(TypedEnvelope {
370 sender_id: connection_id,
371 original_sender_id: envelope.original_sender_id.map(PeerId),
372 message_id,
373 payload: T::from_envelope(envelope).unwrap(),
374 });
375 Some(
376 async move {
377 if let Err(error) = future.await {
378 log::error!(
379 "error handling message {} {}: {:?}",
380 T::NAME,
381 message_id,
382 error
383 );
384 }
385 }
386 .boxed(),
387 )
388 } else {
389 None
390 }
391 }));
392 }
393}
394
395impl ForegroundRouter {
396 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
397 where
398 T: EnvelopedMessage,
399 Fut: 'static + Future<Output = Result<()>>,
400 F: 'static + Fn(TypedEnvelope<T>) -> Fut,
401 {
402 if !self.handler_types.insert(TypeId::of::<T>()) {
403 panic!("duplicate handler type");
404 }
405
406 self.message_handlers
407 .push(Box::new(move |envelope, connection_id| {
408 if envelope.as_ref().map_or(false, T::matches_envelope) {
409 let envelope = Option::take(envelope).unwrap();
410 let message_id = envelope.id;
411 let future = handler(TypedEnvelope {
412 sender_id: connection_id,
413 original_sender_id: envelope.original_sender_id.map(PeerId),
414 message_id,
415 payload: T::from_envelope(envelope).unwrap(),
416 });
417 Some(
418 async move {
419 if let Err(error) = future.await {
420 log::error!(
421 "error handling message {} {}: {:?}",
422 T::NAME,
423 message_id,
424 error
425 );
426 }
427 }
428 .boxed_local(),
429 )
430 } else {
431 None
432 }
433 }));
434 }
435}
436
437impl<T> Clone for Receipt<T> {
438 fn clone(&self) -> Self {
439 Self {
440 sender_id: self.sender_id,
441 message_id: self.message_id,
442 payload_type: PhantomData,
443 }
444 }
445}
446
447impl<T> Copy for Receipt<T> {}
448
449impl fmt::Display for ConnectionId {
450 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
451 self.0.fmt(f)
452 }
453}
454
455impl fmt::Display for PeerId {
456 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
457 self.0.fmt(f)
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Two clients talk to one server through a shared router; each request must
    /// be answered with the response the server's handler produced for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each peer must be disconnected using the id it issued itself.
            client2.disconnect(client2_conn_id).await;
        });
    }

    /// Disconnecting must end both connection futures and close the underlying
    /// channel so the remote side can no longer send.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request on a connection whose remote end is gone must fail with
    /// "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}