1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::{BoxFuture, LocalBoxFuture},
7 FutureExt, StreamExt,
8};
9use postage::{
10 mpsc,
11 prelude::{Sink, Stream},
12};
13use std::{
14 any::TypeId,
15 collections::{HashMap, HashSet},
16 fmt,
17 future::Future,
18 marker::PhantomData,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
/// Identifies one connection registered with a [`Peer`].
///
/// Ids are handed out by `Peer::add_connection` from a per-peer counter, so they are only
/// meaningful relative to the peer that allocated them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
27
/// Identifies the peer on whose behalf a message was forwarded.
///
/// Built from an envelope's `original_sender_id` field, which `Peer::forward_request` and
/// `Peer::forward_send` fill in with the raw id of the forwarding connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
30
31type MessageHandler = Box<
32 dyn Send
33 + Sync
34 + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<'static, ()>>,
35>;
36
37type ForegroundMessageHandler =
38 Box<dyn Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<LocalBoxFuture<'static, ()>>>;
39
40pub struct Receipt<T> {
41 sender_id: ConnectionId,
42 message_id: u32,
43 payload_type: PhantomData<T>,
44}
45
46pub struct TypedEnvelope<T> {
47 pub sender_id: ConnectionId,
48 original_sender_id: Option<PeerId>,
49 pub message_id: u32,
50 pub payload: T,
51}
52
53impl<T> TypedEnvelope<T> {
54 pub fn original_sender_id(&self) -> Result<PeerId> {
55 self.original_sender_id
56 .ok_or_else(|| anyhow!("missing original_sender_id"))
57 }
58}
59
60impl<T: RequestMessage> TypedEnvelope<T> {
61 pub fn receipt(&self) -> Receipt<T> {
62 Receipt {
63 sender_id: self.sender_id,
64 message_id: self.message_id,
65 payload_type: PhantomData,
66 }
67 }
68}
69
70pub type Router = RouterInternal<MessageHandler>;
71pub type ForegroundRouter = RouterInternal<ForegroundMessageHandler>;
72pub struct RouterInternal<H> {
73 message_handlers: Vec<H>,
74 handler_types: HashSet<TypeId>,
75}
76
77pub struct Peer {
78 connections: RwLock<HashMap<ConnectionId, Connection>>,
79 next_connection_id: AtomicU32,
80}
81
82#[derive(Clone)]
83struct Connection {
84 outgoing_tx: mpsc::Sender<proto::Envelope>,
85 next_message_id: Arc<AtomicU32>,
86 response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
87}
88
89impl Peer {
90 pub fn new() -> Arc<Self> {
91 Arc::new(Self {
92 connections: Default::default(),
93 next_connection_id: Default::default(),
94 })
95 }
96
97 pub async fn add_connection<Conn, H, Fut>(
98 self: &Arc<Self>,
99 conn: Conn,
100 router: Arc<RouterInternal<H>>,
101 ) -> (
102 ConnectionId,
103 impl Future<Output = anyhow::Result<()>> + Send,
104 impl Future<Output = ()>,
105 )
106 where
107 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
108 Fut: Future<Output = ()>,
109 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
110 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
111 + Send
112 + Unpin,
113 {
114 let (tx, rx) = conn.split();
115 let connection_id = ConnectionId(
116 self.next_connection_id
117 .fetch_add(1, atomic::Ordering::SeqCst),
118 );
119 let (mut incoming_tx, mut incoming_rx) = mpsc::channel(64);
120 let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
121 let connection = Connection {
122 outgoing_tx,
123 next_message_id: Default::default(),
124 response_channels: Default::default(),
125 };
126 let mut writer = MessageStream::new(tx);
127 let mut reader = MessageStream::new(rx);
128
129 let handle_io = async move {
130 loop {
131 let read_message = reader.read_message().fuse();
132 futures::pin_mut!(read_message);
133 loop {
134 futures::select_biased! {
135 incoming = read_message => match incoming {
136 Ok(incoming) => {
137 if incoming_tx.send(incoming).await.is_err() {
138 return Ok(());
139 }
140 break;
141 }
142 Err(error) => {
143 Err(error).context("received invalid RPC message")?;
144 }
145 },
146 outgoing = outgoing_rx.recv().fuse() => match outgoing {
147 Some(outgoing) => {
148 if let Err(result) = writer.write_message(&outgoing).await {
149 Err(result).context("failed to write RPC message")?;
150 }
151 }
152 None => return Ok(()),
153 }
154 }
155 }
156 }
157 };
158
159 let response_channels = connection.response_channels.clone();
160 let handle_messages = async move {
161 while let Some(message) = incoming_rx.recv().await {
162 if let Some(responding_to) = message.responding_to {
163 let channel = response_channels.lock().await.remove(&responding_to);
164 if let Some(mut tx) = channel {
165 tx.send(message).await.ok();
166 } else {
167 log::warn!("received RPC response to unknown request {}", responding_to);
168 }
169 } else {
170 router.handle(connection_id, message).await;
171 }
172 }
173 response_channels.lock().await.clear();
174 };
175
176 self.connections
177 .write()
178 .await
179 .insert(connection_id, connection);
180
181 (connection_id, handle_io, handle_messages)
182 }
183
184 pub async fn disconnect(&self, connection_id: ConnectionId) {
185 self.connections.write().await.remove(&connection_id);
186 }
187
188 pub async fn reset(&self) {
189 self.connections.write().await.clear();
190 }
191
192 pub fn request<T: RequestMessage>(
193 self: &Arc<Self>,
194 receiver_id: ConnectionId,
195 request: T,
196 ) -> impl Future<Output = Result<T::Response>> {
197 self.request_internal(None, receiver_id, request)
198 }
199
200 pub fn forward_request<T: RequestMessage>(
201 self: &Arc<Self>,
202 sender_id: ConnectionId,
203 receiver_id: ConnectionId,
204 request: T,
205 ) -> impl Future<Output = Result<T::Response>> {
206 self.request_internal(Some(sender_id), receiver_id, request)
207 }
208
209 pub fn request_internal<T: RequestMessage>(
210 self: &Arc<Self>,
211 original_sender_id: Option<ConnectionId>,
212 receiver_id: ConnectionId,
213 request: T,
214 ) -> impl Future<Output = Result<T::Response>> {
215 let this = self.clone();
216 let (tx, mut rx) = mpsc::channel(1);
217 async move {
218 let mut connection = this.connection(receiver_id).await?;
219 let message_id = connection
220 .next_message_id
221 .fetch_add(1, atomic::Ordering::SeqCst);
222 connection
223 .response_channels
224 .lock()
225 .await
226 .insert(message_id, tx);
227 connection
228 .outgoing_tx
229 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
230 .await?;
231 let response = rx
232 .recv()
233 .await
234 .ok_or_else(|| anyhow!("connection was closed"))?;
235 T::Response::from_envelope(response)
236 .ok_or_else(|| anyhow!("received response of the wrong type"))
237 }
238 }
239
240 pub fn send<T: EnvelopedMessage>(
241 self: &Arc<Self>,
242 receiver_id: ConnectionId,
243 message: T,
244 ) -> impl Future<Output = Result<()>> {
245 let this = self.clone();
246 async move {
247 let mut connection = this.connection(receiver_id).await?;
248 let message_id = connection
249 .next_message_id
250 .fetch_add(1, atomic::Ordering::SeqCst);
251 connection
252 .outgoing_tx
253 .send(message.into_envelope(message_id, None, None))
254 .await?;
255 Ok(())
256 }
257 }
258
259 pub fn forward_send<T: EnvelopedMessage>(
260 self: &Arc<Self>,
261 sender_id: ConnectionId,
262 receiver_id: ConnectionId,
263 message: T,
264 ) -> impl Future<Output = Result<()>> {
265 let this = self.clone();
266 async move {
267 let mut connection = this.connection(receiver_id).await?;
268 let message_id = connection
269 .next_message_id
270 .fetch_add(1, atomic::Ordering::SeqCst);
271 connection
272 .outgoing_tx
273 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
274 .await?;
275 Ok(())
276 }
277 }
278
279 pub fn respond<T: RequestMessage>(
280 self: &Arc<Self>,
281 receipt: Receipt<T>,
282 response: T::Response,
283 ) -> impl Future<Output = Result<()>> {
284 let this = self.clone();
285 async move {
286 let mut connection = this.connection(receipt.sender_id).await?;
287 let message_id = connection
288 .next_message_id
289 .fetch_add(1, atomic::Ordering::SeqCst);
290 connection
291 .outgoing_tx
292 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
293 .await?;
294 Ok(())
295 }
296 }
297
298 fn connection(
299 self: &Arc<Self>,
300 connection_id: ConnectionId,
301 ) -> impl Future<Output = Result<Connection>> {
302 let this = self.clone();
303 async move {
304 let connections = this.connections.read().await;
305 let connection = connections
306 .get(&connection_id)
307 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
308 Ok(connection.clone())
309 }
310 }
311}
312
313impl<H, Fut> RouterInternal<H>
314where
315 H: Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<Fut>,
316 Fut: Future<Output = ()>,
317{
318 pub fn new() -> Self {
319 Self {
320 message_handlers: Default::default(),
321 handler_types: Default::default(),
322 }
323 }
324
325 async fn handle(&self, connection_id: ConnectionId, message: proto::Envelope) {
326 let mut envelope = Some(message);
327 for handler in self.message_handlers.iter() {
328 if let Some(future) = handler(&mut envelope, connection_id) {
329 future.await;
330 return;
331 }
332 }
333 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
334 }
335}
336
337impl Router {
338 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
339 where
340 T: EnvelopedMessage,
341 Fut: 'static + Send + Future<Output = Result<()>>,
342 F: 'static + Send + Sync + Fn(TypedEnvelope<T>) -> Fut,
343 {
344 if !self.handler_types.insert(TypeId::of::<T>()) {
345 panic!("duplicate handler type");
346 }
347
348 self.message_handlers
349 .push(Box::new(move |envelope, connection_id| {
350 if envelope.as_ref().map_or(false, T::matches_envelope) {
351 let envelope = Option::take(envelope).unwrap();
352 let message_id = envelope.id;
353 let future = handler(TypedEnvelope {
354 sender_id: connection_id,
355 original_sender_id: envelope.original_sender_id.map(PeerId),
356 message_id,
357 payload: T::from_envelope(envelope).unwrap(),
358 });
359 Some(
360 async move {
361 if let Err(error) = future.await {
362 log::error!(
363 "error handling message {} {}: {:?}",
364 T::NAME,
365 message_id,
366 error
367 );
368 }
369 }
370 .boxed(),
371 )
372 } else {
373 None
374 }
375 }));
376 }
377}
378
379impl ForegroundRouter {
380 pub fn add_message_handler<T, Fut, F>(&mut self, handler: F)
381 where
382 T: EnvelopedMessage,
383 Fut: 'static + Future<Output = Result<()>>,
384 F: 'static + Fn(TypedEnvelope<T>) -> Fut,
385 {
386 if !self.handler_types.insert(TypeId::of::<T>()) {
387 panic!("duplicate handler type");
388 }
389
390 self.message_handlers
391 .push(Box::new(move |envelope, connection_id| {
392 if envelope.as_ref().map_or(false, T::matches_envelope) {
393 let envelope = Option::take(envelope).unwrap();
394 let message_id = envelope.id;
395 let future = handler(TypedEnvelope {
396 sender_id: connection_id,
397 original_sender_id: envelope.original_sender_id.map(PeerId),
398 message_id,
399 payload: T::from_envelope(envelope).unwrap(),
400 });
401 Some(
402 async move {
403 if let Err(error) = future.await {
404 log::error!(
405 "error handling message {} {}: {:?}",
406 T::NAME,
407 message_id,
408 error
409 );
410 }
411 }
412 .boxed_local(),
413 )
414 } else {
415 None
416 }
417 }));
418 }
419}
420
421impl<T> Clone for Receipt<T> {
422 fn clone(&self) -> Self {
423 Self {
424 sender_id: self.sender_id,
425 message_id: self.message_id,
426 payload_type: PhantomData,
427 }
428 }
429}
430
431impl<T> Copy for Receipt<T> {}
432
433impl fmt::Display for ConnectionId {
434 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
435 self.0.fmt(f)
436 }
437}
438
439impl fmt::Display for PeerId {
440 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
441 self.0.fmt(f)
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;

    /// Two clients, each connected to one server, exchange requests through a shared router
    /// and receive the responses the server's handlers produce.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let mut router = Router::new();
            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::Auth>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.user_id {
                                    1 => {
                                        assert_eq!(message.access_token, "access-token-1");
                                        proto::AuthResponse {
                                            credentials_valid: true,
                                        }
                                    }
                                    2 => {
                                        assert_eq!(message.access_token, "access-token-2");
                                        proto::AuthResponse {
                                            credentials_valid: false,
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected user id {}", message.user_id);
                                    }
                                },
                            )
                            .await
                    }
                }
            });

            router.add_message_handler({
                let server = server.clone();
                move |envelope: TypedEnvelope<proto::OpenBuffer>| {
                    let server = server.clone();
                    async move {
                        let receipt = envelope.receipt();
                        let message = envelope.payload;
                        server
                            .respond(
                                receipt,
                                match message.path.as_str() {
                                    "path/one" => {
                                        assert_eq!(message.worktree_id, 1);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 101,
                                                content: "path/one content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    "path/two" => {
                                        assert_eq!(message.worktree_id, 2);
                                        proto::OpenBufferResponse {
                                            buffer: Some(proto::Buffer {
                                                id: 102,
                                                content: "path/two content".to_string(),
                                                history: vec![],
                                                selections: vec![],
                                            }),
                                        }
                                    }
                                    _ => {
                                        panic!("unexpected path {}", message.path);
                                    }
                                },
                            )
                            .await
                    }
                }
            });
            let router = Arc::new(router);

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, msg_task1) = client1
                .add_connection(client1_to_server_conn, router.clone())
                .await;
            let (_, io_task2, msg_task2) = server
                .add_connection(server_to_client_1_conn, router.clone())
                .await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, msg_task3) = client2
                .add_connection(client2_to_server_conn, router.clone())
                .await;
            let (_, io_task4, msg_task4) = server
                .add_connection(server_to_client_2_conn, router.clone())
                .await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(msg_task1).detach();
            smol::spawn(msg_task2).detach();
            smol::spawn(msg_task3).detach();
            smol::spawn(msg_task4).detach();

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::Auth {
                            user_id: 1,
                            access_token: "access-token-1".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: true,
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::Auth {
                            user_id: 2,
                            access_token: "access-token-2".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::AuthResponse {
                    credentials_valid: false,
                }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each client disconnects its own connection; `client2` only knows `client2_conn_id`.
            client2.disconnect(client2_conn_id).await;
        });
    }

    /// Disconnecting a connection ends both of its driver futures and closes the socket.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                message_handler.await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    /// A request over a connection whose other end is gone fails with "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let router = Arc::new(Router::new());
            let (connection_id, io_handler, message_handler) =
                client.add_connection(client_conn, router).await;
            smol::spawn(io_handler).detach();
            smol::spawn(message_handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}