1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 pin::Pin,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
27
/// Identifies one connection registered with a [`Peer`] through `Peer::add_connection`.
///
/// Ids come from a per-peer counter, so they are only unique within a single peer.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct ConnectionId(u32);
30
/// Identifier recorded in an envelope's `original_sender_id` when a message is
/// forwarded on behalf of another connection (see `Peer::forward_send`).
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct PeerId(u32);
33
34struct Connection {
35 writer: Mutex<MessageStream<BoxedWriter>>,
36 reader: Mutex<MessageStream<BoxedReader>>,
37 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
38 next_message_id: AtomicU32,
39}
40
41type MessageHandler = Box<
42 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
43>;
44
45#[derive(Clone, Copy)]
46pub struct Receipt<T> {
47 sender_id: ConnectionId,
48 message_id: u32,
49 payload_type: PhantomData<T>,
50}
51
52pub struct TypedEnvelope<T> {
53 pub sender_id: ConnectionId,
54 pub original_sender_id: Option<PeerId>,
55 pub message_id: u32,
56 pub payload: T,
57}
58
59impl<T: RequestMessage> TypedEnvelope<T> {
60 pub fn receipt(&self) -> Receipt<T> {
61 Receipt {
62 sender_id: self.sender_id,
63 message_id: self.message_id,
64 payload_type: PhantomData,
65 }
66 }
67}
68
69pub struct Peer {
70 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
71 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
72 message_handlers: RwLock<Vec<MessageHandler>>,
73 handler_types: Mutex<HashSet<TypeId>>,
74 next_connection_id: AtomicU32,
75}
76
77impl Peer {
78 pub fn new() -> Arc<Self> {
79 Arc::new(Self {
80 connections: Default::default(),
81 connection_close_barriers: Default::default(),
82 message_handlers: Default::default(),
83 handler_types: Default::default(),
84 next_connection_id: Default::default(),
85 })
86 }
87
88 pub async fn add_message_handler<T: EnvelopedMessage>(
89 &self,
90 ) -> mpsc::Receiver<TypedEnvelope<T>> {
91 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
92 panic!("duplicate handler type");
93 }
94
95 let (tx, rx) = mpsc::channel(256);
96 self.message_handlers
97 .write()
98 .await
99 .push(Box::new(move |envelope, connection_id| {
100 if envelope.as_ref().map_or(false, T::matches_envelope) {
101 let envelope = Option::take(envelope).unwrap();
102 let mut tx = tx.clone();
103 Some(
104 async move {
105 tx.send(TypedEnvelope {
106 sender_id: connection_id,
107 original_sender_id: envelope.original_sender_id.map(PeerId),
108 message_id: envelope.id,
109 payload: T::from_envelope(envelope).unwrap(),
110 })
111 .await
112 .is_err()
113 }
114 .boxed(),
115 )
116 } else {
117 None
118 }
119 }));
120 rx
121 }
122
123 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124 where
125 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126 {
127 let connection_id = ConnectionId(
128 self.next_connection_id
129 .fetch_add(1, atomic::Ordering::SeqCst),
130 );
131 self.connections.write().await.insert(
132 connection_id,
133 Arc::new(Connection {
134 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136 response_channels: Default::default(),
137 next_message_id: Default::default(),
138 }),
139 );
140 connection_id
141 }
142
143 pub async fn disconnect(&self, connection_id: ConnectionId) {
144 self.connections.write().await.remove(&connection_id);
145 self.connection_close_barriers
146 .write()
147 .await
148 .remove(&connection_id);
149 }
150
151 pub async fn reset(&self) {
152 self.connections.write().await.clear();
153 self.connection_close_barriers.write().await.clear();
154 self.handler_types.lock().await.clear();
155 self.message_handlers.write().await.clear();
156 }
157
158 pub fn handle_messages(
159 self: &Arc<Self>,
160 connection_id: ConnectionId,
161 ) -> impl Future<Output = Result<()>> + 'static {
162 let (close_tx, mut close_rx) = barrier::channel();
163 let this = self.clone();
164 async move {
165 this.connection_close_barriers
166 .write()
167 .await
168 .insert(connection_id, close_tx);
169 let connection = this.connection(connection_id).await?;
170 let closed = close_rx.recv();
171 futures::pin_mut!(closed);
172
173 loop {
174 let mut reader = connection.reader.lock().await;
175 let read_message = reader.read_message();
176 futures::pin_mut!(read_message);
177
178 match futures::future::select(read_message, &mut closed).await {
179 Either::Left((Ok(incoming), _)) => {
180 if let Some(responding_to) = incoming.responding_to {
181 let channel = connection
182 .response_channels
183 .lock()
184 .await
185 .remove(&responding_to);
186 if let Some(mut tx) = channel {
187 tx.send(incoming).await.ok();
188 } else {
189 log::warn!(
190 "received RPC response to unknown request {}",
191 responding_to
192 );
193 }
194 } else {
195 let mut envelope = Some(incoming);
196 let mut handler_index = None;
197 let mut handler_was_dropped = false;
198 for (i, handler) in
199 this.message_handlers.read().await.iter().enumerate()
200 {
201 if let Some(future) = handler(&mut envelope, connection_id) {
202 handler_was_dropped = future.await;
203 handler_index = Some(i);
204 break;
205 }
206 }
207
208 if let Some(handler_index) = handler_index {
209 if handler_was_dropped {
210 drop(this.message_handlers.write().await.remove(handler_index));
211 }
212 } else {
213 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
214 }
215 }
216 }
217 Either::Left((Err(error), _)) => {
218 log::warn!("received invalid RPC message: {}", error);
219 Err(error)?;
220 }
221 Either::Right(_) => return Ok(()),
222 }
223 }
224 }
225 }
226
227 pub async fn receive<M: EnvelopedMessage>(
228 self: &Arc<Self>,
229 connection_id: ConnectionId,
230 ) -> Result<TypedEnvelope<M>> {
231 let connection = self.connection(connection_id).await?;
232 let envelope = connection.reader.lock().await.read_message().await?;
233 let original_sender_id = envelope.original_sender_id;
234 let message_id = envelope.id;
235 let payload =
236 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
237 Ok(TypedEnvelope {
238 sender_id: connection_id,
239 original_sender_id: original_sender_id.map(PeerId),
240 message_id,
241 payload,
242 })
243 }
244
245 pub fn request<T: RequestMessage>(
246 self: &Arc<Self>,
247 receiver_id: ConnectionId,
248 request: T,
249 ) -> impl Future<Output = Result<T::Response>> {
250 self.request_internal(None, receiver_id, request)
251 }
252
253 pub fn forward_request<T: RequestMessage>(
254 self: &Arc<Self>,
255 sender_id: ConnectionId,
256 receiver_id: ConnectionId,
257 request: T,
258 ) -> impl Future<Output = Result<T::Response>> {
259 self.request_internal(Some(sender_id), receiver_id, request)
260 }
261
262 pub fn request_internal<T: RequestMessage>(
263 self: &Arc<Self>,
264 original_sender_id: Option<ConnectionId>,
265 receiver_id: ConnectionId,
266 request: T,
267 ) -> impl Future<Output = Result<T::Response>> {
268 let this = self.clone();
269 let (tx, mut rx) = oneshot::channel();
270 async move {
271 let connection = this.connection(receiver_id).await?;
272 let message_id = connection
273 .next_message_id
274 .fetch_add(1, atomic::Ordering::SeqCst);
275 connection
276 .response_channels
277 .lock()
278 .await
279 .insert(message_id, tx);
280 connection
281 .writer
282 .lock()
283 .await
284 .write_message(&request.into_envelope(
285 message_id,
286 None,
287 original_sender_id.map(|id| id.0),
288 ))
289 .await?;
290 let response = rx
291 .recv()
292 .await
293 .expect("response channel was unexpectedly dropped");
294 T::Response::from_envelope(response)
295 .ok_or_else(|| anyhow!("received response of the wrong type"))
296 }
297 }
298
299 pub fn send<T: EnvelopedMessage>(
300 self: &Arc<Self>,
301 connection_id: ConnectionId,
302 message: T,
303 ) -> impl Future<Output = Result<()>> {
304 let this = self.clone();
305 async move {
306 let connection = this.connection(connection_id).await?;
307 let message_id = connection
308 .next_message_id
309 .fetch_add(1, atomic::Ordering::SeqCst);
310 connection
311 .writer
312 .lock()
313 .await
314 .write_message(&message.into_envelope(message_id, None, None))
315 .await?;
316 Ok(())
317 }
318 }
319
320 pub fn forward_send<T: EnvelopedMessage>(
321 self: &Arc<Self>,
322 sender_id: ConnectionId,
323 receiver_id: ConnectionId,
324 message: T,
325 ) -> impl Future<Output = Result<()>> {
326 let this = self.clone();
327 async move {
328 let connection = this.connection(receiver_id).await?;
329 let message_id = connection
330 .next_message_id
331 .fetch_add(1, atomic::Ordering::SeqCst);
332 connection
333 .writer
334 .lock()
335 .await
336 .write_message(&message.into_envelope(message_id, None, Some(sender_id.0)))
337 .await?;
338 Ok(())
339 }
340 }
341
342 pub fn respond<T: RequestMessage>(
343 self: &Arc<Self>,
344 receipt: Receipt<T>,
345 response: T::Response,
346 ) -> impl Future<Output = Result<()>> {
347 let this = self.clone();
348 async move {
349 let connection = this.connection(receipt.sender_id).await?;
350 let message_id = connection
351 .next_message_id
352 .fetch_add(1, atomic::Ordering::SeqCst);
353 connection
354 .writer
355 .lock()
356 .await
357 .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
358 .await?;
359 Ok(())
360 }
361 }
362
363 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
364 Ok(self
365 .connections
366 .read()
367 .await
368 .get(&id)
369 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
370 .clone())
371 }
372}
373
374impl fmt::Display for ConnectionId {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 self.0.fmt(f)
377 }
378}
379
380impl fmt::Display for PeerId {
381 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
382 self.0.fmt(f)
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients issue interleaved requests to one server over Unix sockets and each
    /// receives the response that belongs to its own request.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection. Both ids happen to be 0
            // because every peer numbers its connections independently, which is why
            // passing the wrong id previously went unnoticed.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting ends the message loop and releases the stream, so the remote
    /// end observes a broken pipe.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on a closed stream surfaces the underlying I/O error.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}