1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{
5 future::{BoxFuture, Either},
6 AsyncRead, AsyncWrite, FutureExt,
7};
8use postage::{
9 barrier, mpsc, oneshot,
10 prelude::{Sink, Stream},
11};
12use std::{
13 any::TypeId,
14 collections::{HashMap, HashSet},
15 fmt,
16 future::Future,
17 marker::PhantomData,
18 pin::Pin,
19 sync::{
20 atomic::{self, AtomicU32},
21 Arc,
22 },
23};
24
25type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
26type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
27
/// Identifier that a [`Peer`] assigns locally to each connection it manages.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct ConnectionId(u32);
30
/// Identifier for a peer as reported in a forwarded message's
/// `original_sender_id` field.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub struct PeerId(u32);
33
34struct Connection {
35 writer: Mutex<MessageStream<BoxedWriter>>,
36 reader: Mutex<MessageStream<BoxedReader>>,
37 response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
38 next_message_id: AtomicU32,
39}
40
41type MessageHandler = Box<
42 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
43>;
44
45#[derive(Clone, Copy)]
46pub struct Receipt<T> {
47 sender_id: ConnectionId,
48 message_id: u32,
49 payload_type: PhantomData<T>,
50}
51
52pub struct TypedEnvelope<T> {
53 pub sender_id: ConnectionId,
54 pub original_sender_id: Option<PeerId>,
55 pub message_id: u32,
56 pub payload: T,
57}
58
59impl<T: RequestMessage> TypedEnvelope<T> {
60 pub fn receipt(&self) -> Receipt<T> {
61 Receipt {
62 sender_id: self.sender_id,
63 message_id: self.message_id,
64 payload_type: PhantomData,
65 }
66 }
67}
68
69pub struct Peer {
70 connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
71 connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
72 message_handlers: RwLock<Vec<MessageHandler>>,
73 handler_types: Mutex<HashSet<TypeId>>,
74 next_connection_id: AtomicU32,
75}
76
77impl Peer {
78 pub fn new() -> Arc<Self> {
79 Arc::new(Self {
80 connections: Default::default(),
81 connection_close_barriers: Default::default(),
82 message_handlers: Default::default(),
83 handler_types: Default::default(),
84 next_connection_id: Default::default(),
85 })
86 }
87
88 pub async fn add_message_handler<T: EnvelopedMessage>(
89 &self,
90 ) -> mpsc::Receiver<TypedEnvelope<T>> {
91 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
92 panic!("duplicate handler type");
93 }
94
95 let (tx, rx) = mpsc::channel(256);
96 self.message_handlers
97 .write()
98 .await
99 .push(Box::new(move |envelope, connection_id| {
100 if envelope.as_ref().map_or(false, T::matches_envelope) {
101 let envelope = Option::take(envelope).unwrap();
102 let mut tx = tx.clone();
103 Some(
104 async move {
105 tx.send(TypedEnvelope {
106 sender_id: connection_id,
107 original_sender_id: envelope.original_sender_id.map(PeerId),
108 message_id: envelope.id,
109 payload: T::from_envelope(envelope).unwrap(),
110 })
111 .await
112 .is_err()
113 }
114 .boxed(),
115 )
116 } else {
117 None
118 }
119 }));
120 rx
121 }
122
123 pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
124 where
125 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
126 {
127 let connection_id = ConnectionId(
128 self.next_connection_id
129 .fetch_add(1, atomic::Ordering::SeqCst),
130 );
131 self.connections.write().await.insert(
132 connection_id,
133 Arc::new(Connection {
134 reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
135 writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
136 response_channels: Default::default(),
137 next_message_id: Default::default(),
138 }),
139 );
140 connection_id
141 }
142
143 pub async fn disconnect(&self, connection_id: ConnectionId) {
144 self.connections.write().await.remove(&connection_id);
145 self.connection_close_barriers
146 .write()
147 .await
148 .remove(&connection_id);
149 }
150
151 pub fn handle_messages(
152 self: &Arc<Self>,
153 connection_id: ConnectionId,
154 ) -> impl Future<Output = Result<()>> + 'static {
155 let (close_tx, mut close_rx) = barrier::channel();
156 let this = self.clone();
157 async move {
158 this.connection_close_barriers
159 .write()
160 .await
161 .insert(connection_id, close_tx);
162 let connection = this.connection(connection_id).await?;
163 let closed = close_rx.recv();
164 futures::pin_mut!(closed);
165
166 loop {
167 let mut reader = connection.reader.lock().await;
168 let read_message = reader.read_message();
169 futures::pin_mut!(read_message);
170
171 match futures::future::select(read_message, &mut closed).await {
172 Either::Left((Ok(incoming), _)) => {
173 if let Some(responding_to) = incoming.responding_to {
174 let channel = connection
175 .response_channels
176 .lock()
177 .await
178 .remove(&responding_to);
179 if let Some(mut tx) = channel {
180 tx.send(incoming).await.ok();
181 } else {
182 log::warn!(
183 "received RPC response to unknown request {}",
184 responding_to
185 );
186 }
187 } else {
188 let mut envelope = Some(incoming);
189 let mut handler_index = None;
190 let mut handler_was_dropped = false;
191 for (i, handler) in
192 this.message_handlers.read().await.iter().enumerate()
193 {
194 if let Some(future) = handler(&mut envelope, connection_id) {
195 handler_was_dropped = future.await;
196 handler_index = Some(i);
197 break;
198 }
199 }
200
201 if let Some(handler_index) = handler_index {
202 if handler_was_dropped {
203 drop(this.message_handlers.write().await.remove(handler_index));
204 }
205 } else {
206 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
207 }
208 }
209 }
210 Either::Left((Err(error), _)) => {
211 log::warn!("received invalid RPC message: {}", error);
212 Err(error)?;
213 }
214 Either::Right(_) => return Ok(()),
215 }
216 }
217 }
218 }
219
220 pub async fn receive<M: EnvelopedMessage>(
221 self: &Arc<Self>,
222 connection_id: ConnectionId,
223 ) -> Result<TypedEnvelope<M>> {
224 let connection = self.connection(connection_id).await?;
225 let envelope = connection.reader.lock().await.read_message().await?;
226 let original_sender_id = envelope.original_sender_id;
227 let message_id = envelope.id;
228 let payload =
229 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
230 Ok(TypedEnvelope {
231 sender_id: connection_id,
232 original_sender_id: original_sender_id.map(PeerId),
233 message_id,
234 payload,
235 })
236 }
237
238 pub fn request<T: RequestMessage>(
239 self: &Arc<Self>,
240 receiver_id: ConnectionId,
241 request: T,
242 ) -> impl Future<Output = Result<T::Response>> {
243 self.request_internal(None, receiver_id, request)
244 }
245
246 pub fn forward_request<T: RequestMessage>(
247 self: &Arc<Self>,
248 sender_id: ConnectionId,
249 receiver_id: ConnectionId,
250 request: T,
251 ) -> impl Future<Output = Result<T::Response>> {
252 self.request_internal(Some(sender_id), receiver_id, request)
253 }
254
255 pub fn request_internal<T: RequestMessage>(
256 self: &Arc<Self>,
257 original_sender_id: Option<ConnectionId>,
258 receiver_id: ConnectionId,
259 request: T,
260 ) -> impl Future<Output = Result<T::Response>> {
261 let this = self.clone();
262 let (tx, mut rx) = oneshot::channel();
263 async move {
264 let connection = this.connection(receiver_id).await?;
265 let message_id = connection
266 .next_message_id
267 .fetch_add(1, atomic::Ordering::SeqCst);
268 connection
269 .response_channels
270 .lock()
271 .await
272 .insert(message_id, tx);
273 connection
274 .writer
275 .lock()
276 .await
277 .write_message(&request.into_envelope(
278 message_id,
279 None,
280 original_sender_id.map(|id| id.0),
281 ))
282 .await?;
283 let response = rx
284 .recv()
285 .await
286 .expect("response channel was unexpectedly dropped");
287 T::Response::from_envelope(response)
288 .ok_or_else(|| anyhow!("received response of the wrong type"))
289 }
290 }
291
292 pub fn send<T: EnvelopedMessage>(
293 self: &Arc<Self>,
294 connection_id: ConnectionId,
295 message: T,
296 ) -> impl Future<Output = Result<()>> {
297 let this = self.clone();
298 async move {
299 let connection = this.connection(connection_id).await?;
300 let message_id = connection
301 .next_message_id
302 .fetch_add(1, atomic::Ordering::SeqCst);
303 connection
304 .writer
305 .lock()
306 .await
307 .write_message(&message.into_envelope(message_id, None, None))
308 .await?;
309 Ok(())
310 }
311 }
312
313 pub fn respond<T: RequestMessage>(
314 self: &Arc<Self>,
315 receipt: Receipt<T>,
316 response: T::Response,
317 ) -> impl Future<Output = Result<()>> {
318 let this = self.clone();
319 async move {
320 let connection = this.connection(receipt.sender_id).await?;
321 let message_id = connection
322 .next_message_id
323 .fetch_add(1, atomic::Ordering::SeqCst);
324 connection
325 .writer
326 .lock()
327 .await
328 .write_message(&response.into_envelope(message_id, Some(receipt.message_id), None))
329 .await?;
330 Ok(())
331 }
332 }
333
334 async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
335 Ok(self
336 .connections
337 .read()
338 .await
339 .get(&id)
340 .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
341 .clone())
342 }
343}
344
345impl fmt::Display for ConnectionId {
346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 self.0.fmt(f)
348 }
349}
350
351impl fmt::Display for PeerId {
352 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
353 self.0.fmt(f)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients make interleaved requests to one server; each must get the
    /// response that matches its own request.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: "path/two".to_string(),
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: "path/one".to_string(),
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The originals are still needed below by the clients, so the
                // server task gets its own copies.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            // Each client must disconnect its own connection id.
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting ends the `handle_messages` loop and drops the transport,
    /// so the other end sees a broken pipe.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on an already-closed transport fails with the underlying
    /// I/O error instead of hanging.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}