1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::{BoxFuture, LocalBoxFuture},
7 stream::{SplitSink, SplitStream},
8 FutureExt, StreamExt,
9};
10use postage::{
11 mpsc,
12 prelude::{Sink, Stream},
13};
14use std::{
15 any::TypeId,
16 collections::{HashMap, HashSet},
17 fmt,
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24};
25
/// Identifier a [`Peer`] assigns to each connection it registers; every peer
/// numbers its own connections starting from 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);

/// Identifier of a remote peer, as carried in an envelope's
/// `original_sender_id` field when a message is forwarded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
31
32type MessageHandler = Box<
33 dyn Send
34 + Sync
35 + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
36>;
37
38type ForegroundMessageHandler =
39 Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
40
41pub struct Receipt<T> {
42 sender_id: ConnectionId,
43 message_id: u32,
44 payload_type: PhantomData<T>,
45}
46
47pub struct TypedEnvelope<T> {
48 pub sender_id: ConnectionId,
49 original_sender_id: Option<PeerId>,
50 pub message_id: u32,
51 pub payload: T,
52}
53
54impl<T> TypedEnvelope<T> {
55 pub fn original_sender_id(&self) -> Result<PeerId> {
56 self.original_sender_id
57 .ok_or_else(|| anyhow!("missing original_sender_id"))
58 }
59}
60
61impl<T: RequestMessage> TypedEnvelope<T> {
62 pub fn receipt(&self) -> Receipt<T> {
63 Receipt {
64 sender_id: self.sender_id,
65 message_id: self.message_id,
66 payload_type: PhantomData,
67 }
68 }
69}
70
71pub type Router = RouterInternal<MessageHandler>;
72pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
73pub struct RouterInternal<H> {
74 message_handlers: Vec<H>,
75 handler_types: HashSet<TypeId>,
76}
77
78pub struct Peer {
79 connections: RwLock<HashMap<ConnectionId, Connection>>,
80 next_connection_id: AtomicU32,
81}
82
83#[derive(Clone)]
84struct Connection {
85 outgoing_tx: mpsc::Sender<proto::Envelope>,
86 next_message_id: Arc<AtomicU32>,
87 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
88}
89
90pub struct IOHandler<W, R> {
91 connection_id: ConnectionId,
92 incoming_tx: mpsc::Sender<proto::Envelope>,
93 outgoing_rx: mpsc::Receiver<proto::Envelope>,
94 writer: MessageStream<W>,
95 reader: MessageStream<R>,
96}
97
98impl Peer {
99 pub fn new() -> Arc<Self> {
100 Arc::new(Self {
101 connections: Default::default(),
102 next_connection_id: Default::default(),
103 })
104 }
105
106 pub async fn add_connection<Conn, H, Fut>(
107 self: &Arc<Self>,
108 conn: Conn,
109 router: Arc<RouterInternal<H>>,
110 ) -> (
111 ConnectionId,
112 IOHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
113 impl Future<Output = anyhow::Result<()>>,
114 )
115 where
116 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
117 Fut: Future<Output = ()>,
118 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
119 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
120 + Unpin,
121 {
122 let (tx, rx) = conn.split();
123 let connection_id = ConnectionId(
124 self.next_connection_id
125 .fetch_add(1, atomic::Ordering::SeqCst),
126 );
127 let (incoming_tx, mut incoming_rx) = mpsc::channel(64);
128 let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
129 let connection = Connection {
130 outgoing_tx,
131 next_message_id: Default::default(),
132 response_channels: Default::default(),
133 };
134 let handle_io = IOHandler {
135 connection_id,
136 outgoing_rx,
137 incoming_tx,
138 writer: MessageStream::new(tx),
139 reader: MessageStream::new(rx),
140 };
141
142 let response_channels = connection.response_channels.clone();
143 let handle_messages = async move {
144 while let Some(message) = incoming_rx.recv().await {
145 if let Some(responding_to) = message.responding_to {
146 let channel = response_channels.lock().await.remove(&responding_to);
147 if let Some(mut tx) = channel {
148 tx.send(message).await.ok();
149 } else {
150 log::warn!("received RPC response to unknown request {}", responding_to);
151 }
152 } else {
153 router.handle(connection_id, message).await;
154 }
155 }
156 response_channels.lock().await.clear();
157 Ok(())
158 };
159
160 self.connections
161 .write()
162 .await
163 .insert(connection_id, connection);
164
165 (connection_id, handle_io, handle_messages)
166 }
167
168 pub async fn disconnect(&self, connection_id: ConnectionId) {
169 self.connections.write().await.remove(&connection_id);
170 }
171
172 pub async fn reset(&self) {
173 self.connections.write().await.clear();
174 }
175
176 pub fn request<T: RequestMessage>(
177 self: &Arc<Self>,
178 receiver_id: ConnectionId,
179 request: T,
180 ) -> impl Future<Output = Result<T::Response>> {
181 self.request_internal(None, receiver_id, request)
182 }
183
184 pub fn forward_request<T: RequestMessage>(
185 self: &Arc<Self>,
186 sender_id: ConnectionId,
187 receiver_id: ConnectionId,
188 request: T,
189 ) -> impl Future<Output = Result<T::Response>> {
190 self.request_internal(Some(sender_id), receiver_id, request)
191 }
192
193 pub fn request_internal<T: RequestMessage>(
194 self: &Arc<Self>,
195 original_sender_id: Option<ConnectionId>,
196 receiver_id: ConnectionId,
197 request: T,
198 ) -> impl Future<Output = Result<T::Response>> {
199 let this = self.clone();
200 let (tx, mut rx) = mpsc::channel(1);
201 async move {
202 let mut connection = this.connection(receiver_id).await?;
203 let message_id = connection
204 .next_message_id
205 .fetch_add(1, atomic::Ordering::SeqCst);
206 connection
207 .response_channels
208 .lock()
209 .await
210 .insert(message_id, tx);
211 connection
212 .outgoing_tx
213 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
214 .await?;
215 let response = rx
216 .recv()
217 .await
218 .ok_or_else(|| anyhow!("connection was closed"))?;
219 T::Response::from_envelope(response)
220 .ok_or_else(|| anyhow!("received response of the wrong type"))
221 }
222 }
223
224 pub fn send<T: EnvelopedMessage>(
225 self: &Arc<Self>,
226 receiver_id: ConnectionId,
227 message: T,
228 ) -> impl Future<Output = Result<()>> {
229 let this = self.clone();
230 async move {
231 let mut connection = this.connection(receiver_id).await?;
232 let message_id = connection
233 .next_message_id
234 .fetch_add(1, atomic::Ordering::SeqCst);
235 connection
236 .outgoing_tx
237 .send(message.into_envelope(message_id, None, None))
238 .await?;
239 Ok(())
240 }
241 }
242
243 pub fn forward_send<T: EnvelopedMessage>(
244 self: &Arc<Self>,
245 sender_id: ConnectionId,
246 receiver_id: ConnectionId,
247 message: T,
248 ) -> impl Future<Output = Result<()>> {
249 let this = self.clone();
250 async move {
251 let mut connection = this.connection(receiver_id).await?;
252 let message_id = connection
253 .next_message_id
254 .fetch_add(1, atomic::Ordering::SeqCst);
255 connection
256 .outgoing_tx
257 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
258 .await?;
259 Ok(())
260 }
261 }
262
263 pub fn respond<T: RequestMessage>(
264 self: &Arc<Self>,
265 receipt: Receipt<T>,
266 response: T::Response,
267 ) -> impl Future<Output = Result<()>> {
268 let this = self.clone();
269 async move {
270 let mut connection = this.connection(receipt.sender_id).await?;
271 let message_id = connection
272 .next_message_id
273 .fetch_add(1, atomic::Ordering::SeqCst);
274 connection
275 .outgoing_tx
276 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
277 .await?;
278 Ok(())
279 }
280 }
281
282 fn connection(
283 self: &Arc<Self>,
284 connection_id: ConnectionId,
285 ) -> impl Future<Output = Result<Connection>> {
286 let this = self.clone();
287 async move {
288 let connections = this.connections.read().await;
289 let connection = connections
290 .get(&connection_id)
291 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
292 Ok(connection.clone())
293 }
294 }
295}
296
297impl<H, Fut> RouterInternal<H>
298where
299 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
300 Fut: Future<Output = ()>,
301{
302 pub fn new() -> Self {
303 Self {
304 message_handlers: Default::default(),
305 handler_types: Default::default(),
306 }
307 }
308
309 async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
310 let mut envelope = Some(message);
311 for handler in self.message_handlers.iter() {
312 if let Some(future) = handler(&mut envelope, connection_id) {
313 future.await;
314 return;
315 }
316 }
317 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
318 }
319}
320
321impl Router {
322 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
323 where
324 T: EnvelopedMessage,
325 Fut: 'static + Send + Future<Output = Result<()>>,
326 F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
327 {
328 if !self.handler_types.insert(TypeId::of::<T>()) {
329 panic!("duplicate handler type");
330 }
331
332 self.message_handlers
333 .push(Box::new(move |envelope, connection_id| {
334 if envelope.as_ref().map_or(false, T::matches_envelope) {
335 let envelope = Option::take(envelope).unwrap();
336 let message_id = envelope.id;
337 let future = handler(TypedEnvelope {
338 sender_id: connection_id,
339 original_sender_id: envelope.original_sender_id.map(PeerId),
340 message_id,
341 payload: T::from_envelope(envelope).unwrap(),
342 });
343 Some(
344 async move {
345 if let Err(error) = future.await {
346 log::error!(
347 "error handling message {} {}: {:?}",
348 T::NAME,
349 message_id,
350 error
351 );
352 }
353 }
354 .boxed(),
355 )
356 } else {
357 None
358 }
359 }));
360 }
361}
362
363impl ForegroundRouter {
364 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
365 where
366 T: EnvelopedMessage,
367 Fut: 'static + Future<Output = Result<()>>,
368 F: 'static + Fn(TypedEnvelope<T>) -> Fut,
369 {
370 if !self.handler_types.insert(TypeId::of::<T>()) {
371 panic!("duplicate handler type");
372 }
373
374 self.message_handlers
375 .push(Box::new(move |envelope, connection_id| {
376 if envelope.as_ref().map_or(false, T::matches_envelope) {
377 let envelope = Option::take(envelope).unwrap();
378 let message_id = envelope.id;
379 let future = handler(TypedEnvelope {
380 sender_id: connection_id,
381 original_sender_id: envelope.original_sender_id.map(PeerId),
382 message_id,
383 payload: T::from_envelope(envelope).unwrap(),
384 });
385 Some(
386 async move {
387 if let Err(error) = future.await {
388 log::error!(
389 "error handling message {} {}: {:?}",
390 T::NAME,
391 message_id,
392 error
393 );
394 }
395 }
396 .boxed_local(),
397 )
398 } else {
399 None
400 }
401 }));
402 }
403}
404
405impl<W, R> IOHandler<W, R>
406where
407 W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
408 R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
409{
410 pub async fn run(mut self) -> Result<()> {
411 loop {
412 let read_message = self.reader.read_message().fuse();
413 futures::pin_mut!(read_message);
414 loop {
415 futures::select_biased! {
416 incoming = read_message => match incoming {
417 Ok(incoming) => {
418 if self.incoming_tx.send(incoming).await.is_err() {
419 return Ok(());
420 }
421 break;
422 }
423 Err(error) => {
424 Err(error).context("received invalid RPC message")?;
425 }
426 },
427 outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
428 Some(outgoing) => {
429 if let Err(result) = self.writer.write_message(&outgoing).await {
430 Err(result).context("failed to write RPC message")?;
431 }
432 }
433 None => return Ok(()),
434 }
435 }
436 }
437 }
438 }
439
440 pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
441 let envelope = self.reader.read_message().await?;
442 let original_sender_id = envelope.original_sender_id;
443 let message_id = envelope.id;
444 let payload =
445 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
446 Ok(TypedEnvelope {
447 sender_id: self.connection_id,
448 original_sender_id: original_sender_id.map(PeerId),
449 message_id,
450 payload,
451 })
452 }
453}
454
455impl<T> Clone for Receipt<T> {
456 fn clone(&self) -> Self {
457 Self {
458 sender_id: self.sender_id,
459 message_id: self.message_id,
460 payload_type: PhantomData,
461 }
462 }
463}
464
465impl<T> Copy for Receipt<T> {}
466
467impl fmt::Display for ConnectionId {
468 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469 self.0.fmt(f)
470 }
471}
472
473impl fmt::Display for PeerId {
474 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
475 self.0.fmt(f)
476 }
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1.run()).detach();
            smol::spawn(io_task2.run()).detach();
            smol::spawn(io_task3.run()).detach();
            smol::spawn(io_task4.run()).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each peer must be disconnected via its own connection id. The two
            // ids only happen to be equal because every peer numbers from 0.
            client2.disconnect(client2_conn_id).await;
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.run().await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await.ok();
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler.run()).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}