1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
5use futures::{
6 future::BoxFuture,
7 stream::{SplitSink, SplitStream},
8 FutureExt, StreamExt,
9};
10use postage::{
11 mpsc,
12 prelude::{Sink, Stream},
13};
14use std::{
15 any::TypeId,
16 collections::{HashMap, HashSet},
17 fmt,
18 future::Future,
19 marker::PhantomData,
20 sync::{
21 atomic::{self, AtomicU32},
22 Arc,
23 },
24};
25
/// Identifies one connection registered with a [`Peer`].
///
/// Values come from a per-peer counter in [`Peer::add_connection`], so ids are only
/// unique within a single `Peer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
28
/// Identifies the peer that originally sent a forwarded message.
///
/// Built from the envelope's `original_sender_id`. [`Peer::forward_send`] and
/// [`Peer::forward_request`] fill that field with the forwarding side's
/// [`ConnectionId`] value for the original sender.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct PeerId(pub u32);
31
32type MessageHandler = Box<
33 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
34>;
35
36pub struct Receipt<T> {
37 sender_id: ConnectionId,
38 message_id: u32,
39 payload_type: PhantomData<T>,
40}
41
42pub struct TypedEnvelope<T> {
43 pub sender_id: ConnectionId,
44 original_sender_id: Option<PeerId>,
45 pub message_id: u32,
46 pub payload: T,
47}
48
49impl<T> TypedEnvelope<T> {
50 pub fn original_sender_id(&self) -> Result<PeerId> {
51 self.original_sender_id
52 .ok_or_else(|| anyhow!("missing original_sender_id"))
53 }
54}
55
56impl<T: RequestMessage> TypedEnvelope<T> {
57 pub fn receipt(&self) -> Receipt<T> {
58 Receipt {
59 sender_id: self.sender_id,
60 message_id: self.message_id,
61 payload_type: PhantomData,
62 }
63 }
64}
65
66pub struct Peer {
67 connections: RwLock<HashMap<ConnectionId, Connection>>,
68 message_handlers: RwLock<Vec<MessageHandler>>,
69 handler_types: Mutex<HashSet<TypeId>>,
70 next_connection_id: AtomicU32,
71}
72
73#[derive(Clone)]
74struct Connection {
75 outgoing_tx: mpsc::Sender<proto::Envelope>,
76 next_message_id: Arc<AtomicU32>,
77 response_channels: ResponseChannels,
78}
79
80pub struct ConnectionHandler<W, R> {
81 peer: Arc<Peer>,
82 connection_id: ConnectionId,
83 response_channels: ResponseChannels,
84 outgoing_rx: mpsc::Receiver<proto::Envelope>,
85 writer: MessageStream<W>,
86 reader: MessageStream<R>,
87}
88
89type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
90
91impl Peer {
92 pub fn new() -> Arc<Self> {
93 Arc::new(Self {
94 connections: Default::default(),
95 message_handlers: Default::default(),
96 handler_types: Default::default(),
97 next_connection_id: Default::default(),
98 })
99 }
100
101 pub async fn add_message_handler<T: EnvelopedMessage>(
102 &self,
103 ) -> mpsc::Receiver<TypedEnvelope<T>> {
104 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
105 panic!("duplicate handler type");
106 }
107
108 let (tx, rx) = mpsc::channel(256);
109 self.message_handlers
110 .write()
111 .await
112 .push(Box::new(move |envelope, connection_id| {
113 if envelope.as_ref().map_or(false, T::matches_envelope) {
114 let envelope = Option::take(envelope).unwrap();
115 let mut tx = tx.clone();
116 Some(
117 async move {
118 tx.send(TypedEnvelope {
119 sender_id: connection_id,
120 original_sender_id: envelope.original_sender_id.map(PeerId),
121 message_id: envelope.id,
122 payload: T::from_envelope(envelope).unwrap(),
123 })
124 .await
125 .is_err()
126 }
127 .boxed(),
128 )
129 } else {
130 None
131 }
132 }));
133 rx
134 }
135
136 pub async fn add_connection<Conn>(
137 self: &Arc<Self>,
138 conn: Conn,
139 ) -> (
140 ConnectionId,
141 ConnectionHandler<SplitSink<Conn, WebSocketMessage>, SplitStream<Conn>>,
142 )
143 where
144 Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
145 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
146 + Unpin,
147 {
148 let (tx, rx) = conn.split();
149 let connection_id = ConnectionId(
150 self.next_connection_id
151 .fetch_add(1, atomic::Ordering::SeqCst),
152 );
153 let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
154 let connection = Connection {
155 outgoing_tx,
156 next_message_id: Default::default(),
157 response_channels: Default::default(),
158 };
159 let handler = ConnectionHandler {
160 peer: self.clone(),
161 connection_id,
162 response_channels: connection.response_channels.clone(),
163 outgoing_rx,
164 writer: MessageStream::new(tx),
165 reader: MessageStream::new(rx),
166 };
167 self.connections
168 .write()
169 .await
170 .insert(connection_id, connection);
171 (connection_id, handler)
172 }
173
174 pub async fn disconnect(&self, connection_id: ConnectionId) {
175 self.connections.write().await.remove(&connection_id);
176 }
177
178 pub async fn reset(&self) {
179 self.connections.write().await.clear();
180 self.handler_types.lock().await.clear();
181 self.message_handlers.write().await.clear();
182 }
183
184 pub fn request<T: RequestMessage>(
185 self: &Arc<Self>,
186 receiver_id: ConnectionId,
187 request: T,
188 ) -> impl Future<Output = Result<T::Response>> {
189 self.request_internal(None, receiver_id, request)
190 }
191
192 pub fn forward_request<T: RequestMessage>(
193 self: &Arc<Self>,
194 sender_id: ConnectionId,
195 receiver_id: ConnectionId,
196 request: T,
197 ) -> impl Future<Output = Result<T::Response>> {
198 self.request_internal(Some(sender_id), receiver_id, request)
199 }
200
201 pub fn request_internal<T: RequestMessage>(
202 self: &Arc<Self>,
203 original_sender_id: Option<ConnectionId>,
204 receiver_id: ConnectionId,
205 request: T,
206 ) -> impl Future<Output = Result<T::Response>> {
207 let this = self.clone();
208 let (tx, mut rx) = mpsc::channel(1);
209 async move {
210 let mut connection = this.connection(receiver_id).await?;
211 let message_id = connection
212 .next_message_id
213 .fetch_add(1, atomic::Ordering::SeqCst);
214 connection
215 .response_channels
216 .lock()
217 .await
218 .insert(message_id, tx);
219 connection
220 .outgoing_tx
221 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
222 .await?;
223 let response = rx
224 .recv()
225 .await
226 .ok_or_else(|| anyhow!("connection was closed"))?;
227 T::Response::from_envelope(response)
228 .ok_or_else(|| anyhow!("received response of the wrong type"))
229 }
230 }
231
232 pub fn send<T: EnvelopedMessage>(
233 self: &Arc<Self>,
234 receiver_id: ConnectionId,
235 message: T,
236 ) -> impl Future<Output = Result<()>> {
237 let this = self.clone();
238 async move {
239 let mut connection = this.connection(receiver_id).await?;
240 let message_id = connection
241 .next_message_id
242 .fetch_add(1, atomic::Ordering::SeqCst);
243 connection
244 .outgoing_tx
245 .send(message.into_envelope(message_id, None, None))
246 .await?;
247 Ok(())
248 }
249 }
250
251 pub fn forward_send<T: EnvelopedMessage>(
252 self: &Arc<Self>,
253 sender_id: ConnectionId,
254 receiver_id: ConnectionId,
255 message: T,
256 ) -> impl Future<Output = Result<()>> {
257 let this = self.clone();
258 async move {
259 let mut connection = this.connection(receiver_id).await?;
260 let message_id = connection
261 .next_message_id
262 .fetch_add(1, atomic::Ordering::SeqCst);
263 connection
264 .outgoing_tx
265 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
266 .await?;
267 Ok(())
268 }
269 }
270
271 pub fn respond<T: RequestMessage>(
272 self: &Arc<Self>,
273 receipt: Receipt<T>,
274 response: T::Response,
275 ) -> impl Future<Output = Result<()>> {
276 let this = self.clone();
277 async move {
278 let mut connection = this.connection(receipt.sender_id).await?;
279 let message_id = connection
280 .next_message_id
281 .fetch_add(1, atomic::Ordering::SeqCst);
282 connection
283 .outgoing_tx
284 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
285 .await?;
286 Ok(())
287 }
288 }
289
290 fn connection(
291 self: &Arc<Self>,
292 connection_id: ConnectionId,
293 ) -> impl Future<Output = Result<Connection>> {
294 let this = self.clone();
295 async move {
296 let connections = this.connections.read().await;
297 let connection = connections
298 .get(&connection_id)
299 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
300 Ok(connection.clone())
301 }
302 }
303}
304
305impl<W, R> ConnectionHandler<W, R>
306where
307 W: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
308 R: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
309{
310 pub async fn run(mut self) -> Result<()> {
311 loop {
312 let read_message = self.reader.read_message().fuse();
313 futures::pin_mut!(read_message);
314 loop {
315 futures::select_biased! {
316 incoming = read_message => match incoming {
317 Ok(incoming) => {
318 Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
319 break;
320 }
321 Err(error) => {
322 self.response_channels.lock().await.clear();
323 Err(error).context("received invalid RPC message")?;
324 }
325 },
326 outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
327 Some(outgoing) => {
328 if let Err(result) = self.writer.write_message(&outgoing).await {
329 self.response_channels.lock().await.clear();
330 Err(result).context("failed to write RPC message")?;
331 }
332 }
333 None => return Ok(()),
334 }
335 }
336 }
337 }
338 }
339
340 pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
341 let envelope = self.reader.read_message().await?;
342 let original_sender_id = envelope.original_sender_id;
343 let message_id = envelope.id;
344 let payload =
345 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
346 Ok(TypedEnvelope {
347 sender_id: self.connection_id,
348 original_sender_id: original_sender_id.map(PeerId),
349 message_id,
350 payload,
351 })
352 }
353
354 async fn handle_incoming_message(
355 message: proto::Envelope,
356 peer: &Arc<Peer>,
357 connection_id: ConnectionId,
358 response_channels: &ResponseChannels,
359 ) {
360 if let Some(responding_to) = message.responding_to {
361 let channel = response_channels.lock().await.remove(&responding_to);
362 if let Some(mut tx) = channel {
363 tx.send(message).await.ok();
364 } else {
365 log::warn!("received RPC response to unknown request {}", responding_to);
366 }
367 } else {
368 let mut envelope = Some(message);
369 let mut handler_index = None;
370 let mut handler_was_dropped = false;
371 for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
372 if let Some(future) = handler(&mut envelope, connection_id) {
373 handler_was_dropped = future.await;
374 handler_index = Some(i);
375 break;
376 }
377 }
378
379 if let Some(handler_index) = handler_index {
380 if handler_was_dropped {
381 drop(peer.message_handlers.write().await.remove(handler_index));
382 }
383 } else {
384 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
385 }
386 }
387 }
388}
389
390impl<T> Clone for Receipt<T> {
391 fn clone(&self) -> Self {
392 Self {
393 sender_id: self.sender_id,
394 message_id: self.message_id,
395 payload_type: PhantomData,
396 }
397 }
398}
399
400impl<T> Copy for Receipt<T> {}
401
402impl fmt::Display for ConnectionId {
403 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404 self.0.fmt(f)
405 }
406}
407
408impl fmt::Display for PeerId {
409 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
410 self.0.fmt(f)
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test;
    use postage::oneshot;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, task1) = client1.add_connection(client1_to_server_conn).await;
            let (_, task2) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, task3) = client2.add_connection(client2_to_server_conn).await;
            let (_, task4) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(task1.run()).detach();
            smol::spawn(task2.run()).detach();
            smol::spawn(task3.run()).detach();
            smol::spawn(task4.run()).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The server task owns its own copies; the originals are used by the clients.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                postage::barrier::channel();
            smol::spawn(async move {
                handler.run().await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler.run()).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}