1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::{BoxFuture, LocalBoxFuture},
7 FutureExt, StreamExt,
8};
9use postage::{
10 mpsc,
11 prelude::{Sink, Stream},
12};
13use std::{
14 any::TypeId,
15 collections::{HashMap, HashSet},
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
/// Identifies one connection registered on a [`Peer`].
///
/// Ids come from a per-`Peer` counter, so the same numeric value may name
/// different connections on different peers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
27
/// Identifies the peer on whose behalf a message was forwarded.
///
/// It is carried in an envelope's `original_sender_id`. When forwarding, the
/// sending side fills that field with the raw value of the originating
/// [`ConnectionId`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
30
31type MessageHandler = Box<
32 dyn Send
33 + Sync
34 + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
35>;
36
37type ForegroundMessageHandler =
38 Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
39
40pub struct Receipt<T> {
41 sender_id: ConnectionId,
42 message_id: u32,
43 payload_type: PhantomData<T>,
44}
45
46pub struct TypedEnvelope<T> {
47 pub sender_id: ConnectionId,
48 original_sender_id: Option<PeerId>,
49 pub message_id: u32,
50 pub payload: T,
51}
52
53impl<T> TypedEnvelope<T> {
54 pub fn original_sender_id(&self) -> Result<PeerId> {
55 self.original_sender_id
56 .ok_or_else(|| anyhow!("missing original_sender_id"))
57 }
58}
59
60impl<T: RequestMessage> TypedEnvelope<T> {
61 pub fn receipt(&self) -> Receipt<T> {
62 Receipt {
63 sender_id: self.sender_id,
64 message_id: self.message_id,
65 payload_type: PhantomData,
66 }
67 }
68}
69
70pub type Router = RouterInternal<MessageHandler>;
71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
72pub struct RouterInternal<H> {
73 message_handlers: Vec<H>,
74 handler_types: HashSet<TypeId>,
75}
76
77pub struct Peer {
78 connections: RwLock<HashMap<ConnectionId, Connection>>,
79 next_connection_id: AtomicU32,
80}
81
82#[derive(Clone)]
83struct Connection {
84 outgoing_tx: mpsc::Sender<proto::Envelope>,
85 next_message_id: Arc<AtomicU32>,
86 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
87}
88
89impl Peer {
90 pub fn new() -> Arc<Self> {
91 Arc::new(Self {
92 connections: Default::default(),
93 next_connection_id: Default::default(),
94 })
95 }
96
97 pub async fn add_connection<Conn, H, Fut>(
98 self: &Arc<Self>,
99 conn: Conn,
100 router: Arc<RouterInternal<H>>,
101 ) -> (
102 ConnectionId,
103 impl Future<Output = anyhow::Result<()>> + Send,
104 impl Future<Output = ()>,
105 )
106 where
107 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
108 Fut: Future<Output = ()>,
109 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
110 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
111 + Send
112 + Unpin,
113 {
114 let (tx, rx) = conn.split();
115 let connection_id = ConnectionId(
116 self.next_connection_id
117 .fetch_add(1, atomic::Ordering::SeqCst),
118 );
119 let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
120 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
121 let connection = Connection {
122 outgoing_tx,
123 next_message_id: Default::default(),
124 response_channels: Default::default(),
125 };
126 let mut writer = MessageStream::new(tx);
127 let mut reader = MessageStream::new(rx);
128
129 let handle_io = async move {
130 loop {
131 let read_message = reader.read_message().fuse();
132 futures::pin_mut!(read_message);
133 loop {
134 futures::select_biased! {
135 incoming = read_message => match incoming {
136 Ok(incoming) => {
137 if incoming_tx.send(incoming).await.is_err() {
138 return Ok(());
139 }
140 break;
141 }
142 Err(error) => {
143 Err(error).context("received invalid RPC message")?;
144 }
145 },
146 outgoing = outgoing_rx.recv().fuse() => match outgoing {
147 Some(outgoing) => {
148 if let Err(result) = writer.write_message(&outgoing).await {
149 Err(result).context("failed to write RPC message")?;
150 }
151 }
152 None => return Ok(()),
153 }
154 }
155 }
156 }
157 };
158
159 let response_channels = connection.response_channels.clone();
160 let handle_messages = async move {
161 while let Some(message) = incoming_rx.recv().await {
162 if let Some(responding_to) = message.responding_to {
163 let channel = response_channels.lock().await.remove(&responding_to);
164 if let Some(mut tx) = channel {
165 tx.send(message).await.ok();
166 } else {
167 log::warn!("received RPC response to unknown request {}", responding_to);
168 }
169 } else {
170 router.handle(connection_id, message).await;
171 }
172 }
173 response_channels.lock().await.clear();
174 };
175
176 self.connections
177 .write()
178 .await
179 .insert(connection_id, connection);
180
181 (connection_id, handle_io, handle_messages)
182 }
183
184 pub async fn disconnect(&self, connection_id: ConnectionId) {
185 self.connections.write().await.remove(&connection_id);
186 }
187
188 pub async fn reset(&self) {
189 self.connections.write().await.clear();
190 }
191
192 pub fn request<T: RequestMessage>(
193 self: &Arc<Self>,
194 receiver_id: ConnectionId,
195 request: T,
196 ) -> impl Future<Output = Result<T::Response>> {
197 self.request_internal(None, receiver_id, request)
198 }
199
200 pub fn forward_request<T: RequestMessage>(
201 self: &Arc<Self>,
202 sender_id: ConnectionId,
203 receiver_id: ConnectionId,
204 request: T,
205 ) -> impl Future<Output = Result<T::Response>> {
206 self.request_internal(Some(sender_id), receiver_id, request)
207 }
208
209 pub fn request_internal<T: RequestMessage>(
210 self: &Arc<Self>,
211 original_sender_id: Option<ConnectionId>,
212 receiver_id: ConnectionId,
213 request: T,
214 ) -> impl Future<Output = Result<T::Response>> {
215 let this = self.clone();
216 let (tx, mut rx) = mpsc::channel(1);
217 async move {
218 let mut connection = this.connection(receiver_id).await?;
219 let message_id = connection
220 .next_message_id
221 .fetch_add(1, atomic::Ordering::SeqCst);
222 connection
223 .response_channels
224 .lock()
225 .await
226 .insert(message_id, tx);
227 connection
228 .outgoing_tx
229 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
230 .await
231 .map_err(|_| anyhow!("connection was closed"))?;
232 let response = rx
233 .recv()
234 .await
235 .ok_or_else(|| anyhow!("connection was closed"))?;
236 T::Response::from_envelope(response)
237 .ok_or_else(|| anyhow!("received response of the wrong type"))
238 }
239 }
240
241 pub fn send<T: EnvelopedMessage>(
242 self: &Arc<Self>,
243 receiver_id: ConnectionId,
244 message: T,
245 ) -> impl Future<Output = Result<()>> {
246 let this = self.clone();
247 async move {
248 let mut connection = this.connection(receiver_id).await?;
249 let message_id = connection
250 .next_message_id
251 .fetch_add(1, atomic::Ordering::SeqCst);
252 connection
253 .outgoing_tx
254 .send(message.into_envelope(message_id, None, None))
255 .await?;
256 Ok(())
257 }
258 }
259
260 pub fn forward_send<T: EnvelopedMessage>(
261 self: &Arc<Self>,
262 sender_id: ConnectionId,
263 receiver_id: ConnectionId,
264 message: T,
265 ) -> impl Future<Output = Result<()>> {
266 let this = self.clone();
267 async move {
268 let mut connection = this.connection(receiver_id).await?;
269 let message_id = connection
270 .next_message_id
271 .fetch_add(1, atomic::Ordering::SeqCst);
272 connection
273 .outgoing_tx
274 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
275 .await?;
276 Ok(())
277 }
278 }
279
280 pub fn respond<T: RequestMessage>(
281 self: &Arc<Self>,
282 receipt: Receipt<T>,
283 response: T::Response,
284 ) -> impl Future<Output = Result<()>> {
285 let this = self.clone();
286 async move {
287 let mut connection = this.connection(receipt.sender_id).await?;
288 let message_id = connection
289 .next_message_id
290 .fetch_add(1, atomic::Ordering::SeqCst);
291 connection
292 .outgoing_tx
293 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
294 .await?;
295 Ok(())
296 }
297 }
298
299 fn connection(
300 self: &Arc<Self>,
301 connection_id: ConnectionId,
302 ) -> impl Future<Output = Result<Connection>> {
303 let this = self.clone();
304 async move {
305 let connections = this.connections.read().await;
306 let connection = connections
307 .get(&connection_id)
308 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
309 Ok(connection.clone())
310 }
311 }
312}
313
314impl<H, Fut> RouterInternal<H>
315where
316 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
317 Fut: Future<Output = ()>,
318{
319 pub fn new() -> Self {
320 Self {
321 message_handlers: Default::default(),
322 handler_types: Default::default(),
323 }
324 }
325
326 async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
327 let mut envelope = Some(message);
328 for handler in self.message_handlers.iter() {
329 if let Some(future) = handler(&mut envelope, connection_id) {
330 future.await;
331 return;
332 }
333 }
334 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
335 }
336}
337
338impl Router {
339 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
340 where
341 T: EnvelopedMessage,
342 Fut: 'static + Send + Future<Output = Result<()>>,
343 F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
344 {
345 if !self.handler_types.insert(TypeId::of::<T>()) {
346 panic!("duplicate handler type");
347 }
348
349 self.message_handlers
350 .push(Box::new(move |envelope, connection_id| {
351 if envelope.as_ref().map_or(false, T::matches_envelope) {
352 let envelope = Option::take(envelope).unwrap();
353 let message_id = envelope.id;
354 let future = handler(TypedEnvelope {
355 sender_id: connection_id,
356 original_sender_id: envelope.original_sender_id.map(PeerId),
357 message_id,
358 payload: T::from_envelope(envelope).unwrap(),
359 });
360 Some(
361 async move {
362 if let Err(error) = future.await {
363 log::error!(
364 "error handling message {} {}: {:?}",
365 T::NAME,
366 message_id,
367 error
368 );
369 }
370 }
371 .boxed(),
372 )
373 } else {
374 None
375 }
376 }));
377 }
378}
379
380impl ForegroundRouter {
381 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
382 where
383 T: EnvelopedMessage,
384 Fut: 'static + Future<Output = Result<()>>,
385 F: 'static + Fn(TypedEnvelope<T>) -> Fut,
386 {
387 if !self.handler_types.insert(TypeId::of::<T>()) {
388 panic!("duplicate handler type");
389 }
390
391 self.message_handlers
392 .push(Box::new(move |envelope, connection_id| {
393 if envelope.as_ref().map_or(false, T::matches_envelope) {
394 let envelope = Option::take(envelope).unwrap();
395 let message_id = envelope.id;
396 let future = handler(TypedEnvelope {
397 sender_id: connection_id,
398 original_sender_id: envelope.original_sender_id.map(PeerId),
399 message_id,
400 payload: T::from_envelope(envelope).unwrap(),
401 });
402 Some(
403 async move {
404 if let Err(error) = future.await {
405 log::error!(
406 "error handling message {} {}: {:?}",
407 T::NAME,
408 message_id,
409 error
410 );
411 }
412 }
413 .boxed_local(),
414 )
415 } else {
416 None
417 }
418 }));
419 }
420}
421
422impl<T> Clone for Receipt<T> {
423 fn clone(&self) -> Self {
424 Self {
425 sender_id: self.sender_id,
426 message_id: self.message_id,
427 payload_type: PhantomData,
428 }
429 }
430}
431
432impl<T> Copy for Receipt<T> {}
433
434impl fmt::Display for ConnectionId {
435 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
436 self.0.fmt(f)
437 }
438}
439
440impl fmt::Display for PeerId {
441 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
442 self.0.fmt(f)
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each client must disconnect its own connection id. The ids only coincide
            // by accident, because every `Peer` allocates ids from its own counter.
            client2.disconnect(client2_conn_id).await;
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}