1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
5use postage::{
6 mpsc,
7 prelude::{Sink, Stream},
8};
9use std::{
10 any::TypeId,
11 collections::{HashMap, HashSet},
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies one connection of a [`Peer`]; ids are handed out by each peer from its own
/// counter, so they are only unique within that peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
23
/// Identifies the original sender of a forwarded message.
///
/// The wrapped value is whatever the forwarding peer put in the envelope's
/// `original_sender_id` field (the forwarder's `ConnectionId` number for that sender).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
26
27type MessageHandler = Box<
28 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
29>;
30
31pub struct Receipt<T> {
32 sender_id: ConnectionId,
33 message_id: u32,
34 payload_type: PhantomData<T>,
35}
36
37pub struct TypedEnvelope<T> {
38 pub sender_id: ConnectionId,
39 original_sender_id: Option<PeerId>,
40 pub message_id: u32,
41 pub payload: T,
42}
43
44impl<T> TypedEnvelope<T> {
45 pub fn original_sender_id(&self) -> Result<PeerId> {
46 self.original_sender_id
47 .ok_or_else(|| anyhow!("missing original_sender_id"))
48 }
49}
50
51impl<T: RequestMessage> TypedEnvelope<T> {
52 pub fn receipt(&self) -> Receipt<T> {
53 Receipt {
54 sender_id: self.sender_id,
55 message_id: self.message_id,
56 payload_type: PhantomData,
57 }
58 }
59}
60
61pub struct Peer {
62 connections: RwLock<HashMap<ConnectionId, Connection>>,
63 message_handlers: RwLock<Vec<MessageHandler>>,
64 handler_types: Mutex<HashSet<TypeId>>,
65 next_connection_id: AtomicU32,
66}
67
68#[derive(Clone)]
69struct Connection {
70 outgoing_tx: mpsc::Sender<proto::Envelope>,
71 next_message_id: Arc<AtomicU32>,
72 response_channels: ResponseChannels,
73}
74
75pub struct ConnectionHandler<Conn> {
76 peer: Arc<Peer>,
77 connection_id: ConnectionId,
78 response_channels: ResponseChannels,
79 outgoing_rx: mpsc::Receiver<proto::Envelope>,
80 reader: MessageStream<Conn>,
81 writer: MessageStream<Conn>,
82}
83
84type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
85
86impl Peer {
87 pub fn new() -> Arc<Self> {
88 Arc::new(Self {
89 connections: Default::default(),
90 message_handlers: Default::default(),
91 handler_types: Default::default(),
92 next_connection_id: Default::default(),
93 })
94 }
95
96 pub async fn add_message_handler<T: EnvelopedMessage>(
97 &self,
98 ) -> mpsc::Receiver<TypedEnvelope<T>> {
99 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
100 panic!("duplicate handler type");
101 }
102
103 let (tx, rx) = mpsc::channel(256);
104 self.message_handlers
105 .write()
106 .await
107 .push(Box::new(move |envelope, connection_id| {
108 if envelope.as_ref().map_or(false, T::matches_envelope) {
109 let envelope = Option::take(envelope).unwrap();
110 let mut tx = tx.clone();
111 Some(
112 async move {
113 tx.send(TypedEnvelope {
114 sender_id: connection_id,
115 original_sender_id: envelope.original_sender_id.map(PeerId),
116 message_id: envelope.id,
117 payload: T::from_envelope(envelope).unwrap(),
118 })
119 .await
120 .is_err()
121 }
122 .boxed(),
123 )
124 } else {
125 None
126 }
127 }));
128 rx
129 }
130
131 pub async fn add_connection<Conn>(
132 self: &Arc<Self>,
133 conn: Conn,
134 ) -> (ConnectionId, ConnectionHandler<Conn>)
135 where
136 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
137 {
138 let connection_id = ConnectionId(
139 self.next_connection_id
140 .fetch_add(1, atomic::Ordering::SeqCst),
141 );
142 let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
143 let connection = Connection {
144 outgoing_tx,
145 next_message_id: Default::default(),
146 response_channels: Default::default(),
147 };
148 let handler = ConnectionHandler {
149 peer: self.clone(),
150 connection_id,
151 response_channels: connection.response_channels.clone(),
152 outgoing_rx,
153 reader: MessageStream::new(conn.clone()),
154 writer: MessageStream::new(conn),
155 };
156 self.connections
157 .write()
158 .await
159 .insert(connection_id, connection);
160 (connection_id, handler)
161 }
162
163 pub async fn disconnect(&self, connection_id: ConnectionId) {
164 self.connections.write().await.remove(&connection_id);
165 }
166
167 pub async fn reset(&self) {
168 self.connections.write().await.clear();
169 self.handler_types.lock().await.clear();
170 self.message_handlers.write().await.clear();
171 }
172
173 pub fn request<T: RequestMessage>(
174 self: &Arc<Self>,
175 receiver_id: ConnectionId,
176 request: T,
177 ) -> impl Future<Output = Result<T::Response>> {
178 self.request_internal(None, receiver_id, request)
179 }
180
181 pub fn forward_request<T: RequestMessage>(
182 self: &Arc<Self>,
183 sender_id: ConnectionId,
184 receiver_id: ConnectionId,
185 request: T,
186 ) -> impl Future<Output = Result<T::Response>> {
187 self.request_internal(Some(sender_id), receiver_id, request)
188 }
189
190 pub fn request_internal<T: RequestMessage>(
191 self: &Arc<Self>,
192 original_sender_id: Option<ConnectionId>,
193 receiver_id: ConnectionId,
194 request: T,
195 ) -> impl Future<Output = Result<T::Response>> {
196 let this = self.clone();
197 let (tx, mut rx) = mpsc::channel(1);
198 async move {
199 let mut connection = this.connection(receiver_id).await?;
200 let message_id = connection
201 .next_message_id
202 .fetch_add(1, atomic::Ordering::SeqCst);
203 connection
204 .response_channels
205 .lock()
206 .await
207 .insert(message_id, tx);
208 connection
209 .outgoing_tx
210 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
211 .await?;
212 let response = rx
213 .recv()
214 .await
215 .ok_or_else(|| anyhow!("connection was closed"))?;
216 T::Response::from_envelope(response)
217 .ok_or_else(|| anyhow!("received response of the wrong type"))
218 }
219 }
220
221 pub fn send<T: EnvelopedMessage>(
222 self: &Arc<Self>,
223 receiver_id: ConnectionId,
224 message: T,
225 ) -> impl Future<Output = Result<()>> {
226 let this = self.clone();
227 async move {
228 let mut connection = this.connection(receiver_id).await?;
229 let message_id = connection
230 .next_message_id
231 .fetch_add(1, atomic::Ordering::SeqCst);
232 connection
233 .outgoing_tx
234 .send(message.into_envelope(message_id, None, None))
235 .await?;
236 Ok(())
237 }
238 }
239
240 pub fn forward_send<T: EnvelopedMessage>(
241 self: &Arc<Self>,
242 sender_id: ConnectionId,
243 receiver_id: ConnectionId,
244 message: T,
245 ) -> impl Future<Output = Result<()>> {
246 let this = self.clone();
247 async move {
248 let mut connection = this.connection(receiver_id).await?;
249 let message_id = connection
250 .next_message_id
251 .fetch_add(1, atomic::Ordering::SeqCst);
252 connection
253 .outgoing_tx
254 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
255 .await?;
256 Ok(())
257 }
258 }
259
260 pub fn respond<T: RequestMessage>(
261 self: &Arc<Self>,
262 receipt: Receipt<T>,
263 response: T::Response,
264 ) -> impl Future<Output = Result<()>> {
265 let this = self.clone();
266 async move {
267 let mut connection = this.connection(receipt.sender_id).await?;
268 let message_id = connection
269 .next_message_id
270 .fetch_add(1, atomic::Ordering::SeqCst);
271 connection
272 .outgoing_tx
273 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
274 .await?;
275 Ok(())
276 }
277 }
278
279 fn connection(
280 self: &Arc<Self>,
281 connection_id: ConnectionId,
282 ) -> impl Future<Output = Result<Connection>> {
283 let this = self.clone();
284 async move {
285 let connections = this.connections.read().await;
286 let connection = connections
287 .get(&connection_id)
288 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
289 Ok(connection.clone())
290 }
291 }
292}
293
294impl<Conn> ConnectionHandler<Conn>
295where
296 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
297{
298 pub async fn run(mut self) -> Result<()> {
299 loop {
300 let read_message = self.reader.read_message().fuse();
301 futures::pin_mut!(read_message);
302 loop {
303 futures::select! {
304 incoming = read_message => match incoming {
305 Ok(incoming) => {
306 Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
307 break;
308 }
309 Err(error) => {
310 self.response_channels.lock().await.clear();
311 Err(error).context("received invalid RPC message")?;
312 }
313 },
314 outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
315 Some(outgoing) => {
316 if let Err(result) = self.writer.write_message(&outgoing).await {
317 self.response_channels.lock().await.clear();
318 Err(result).context("failed to write RPC message")?;
319 }
320 }
321 None => return Ok(()),
322 }
323 }
324 }
325 }
326 }
327
328 pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
329 let envelope = self.reader.read_message().await?;
330 let original_sender_id = envelope.original_sender_id;
331 let message_id = envelope.id;
332 let payload =
333 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
334 Ok(TypedEnvelope {
335 sender_id: self.connection_id,
336 original_sender_id: original_sender_id.map(PeerId),
337 message_id,
338 payload,
339 })
340 }
341
342 async fn handle_incoming_message(
343 message: proto::Envelope,
344 peer: &Arc<Peer>,
345 connection_id: ConnectionId,
346 response_channels: &ResponseChannels,
347 ) {
348 if let Some(responding_to) = message.responding_to {
349 let channel = response_channels.lock().await.remove(&responding_to);
350 if let Some(mut tx) = channel {
351 tx.send(message).await.ok();
352 } else {
353 log::warn!("received RPC response to unknown request {}", responding_to);
354 }
355 } else {
356 let mut envelope = Some(message);
357 let mut handler_index = None;
358 let mut handler_was_dropped = false;
359 for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
360 if let Some(future) = handler(&mut envelope, connection_id) {
361 handler_was_dropped = future.await;
362 handler_index = Some(i);
363 break;
364 }
365 }
366
367 if let Some(handler_index) = handler_index {
368 if handler_was_dropped {
369 drop(peer.message_handlers.write().await.remove(handler_index));
370 }
371 } else {
372 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
373 }
374 }
375 }
376}
377
378impl<T> Clone for Receipt<T> {
379 fn clone(&self) -> Self {
380 Self {
381 sender_id: self.sender_id,
382 message_id: self.message_id,
383 payload_type: PhantomData,
384 }
385 }
386}
387
388impl<T> Copy for Receipt<T> {}
389
390impl fmt::Display for ConnectionId {
391 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392 self.0.fmt(f)
393 }
394}
395
396impl fmt::Display for PeerId {
397 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398 self.0.fmt(f)
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use postage::oneshot;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let (client1_conn_id, task1) = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (client2_conn_id, task2) = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (_, task3) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let (_, task4) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(task1.run()).detach();
            smol::spawn(task2.run()).detach();
            smol::spawn(task3.run()).detach();
            smol::spawn(task4.run()).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The task gets its own copies; the originals are used by the clients below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection. Ids are allocated per peer, so
            // both happen to be `ConnectionId(0)`; use the matching id regardless.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                postage::barrier::channel();
            smol::spawn(async move {
                handler.run().await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler.run()).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}