1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
2use anyhow::{anyhow, Context, Result};
3use async_lock::{Mutex, RwLock};
4use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
5use postage::{
6 mpsc,
7 prelude::{Sink, Stream},
8};
9use std::{
10 any::TypeId,
11 collections::{HashMap, HashSet},
12 fmt,
13 future::Future,
14 marker::PhantomData,
15 sync::{
16 atomic::{self, AtomicU32},
17 Arc,
18 },
19};
20
/// Identifies one connection registered with a `Peer`.
///
/// `Peer::add_connection` allocates these from the peer's own counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u32);
23
/// Identifies the peer recorded as the original sender of a message.
///
/// `TypedEnvelope` builds it from the envelope's `original_sender_id` field.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PeerId(pub u32);
26
27type MessageHandler = Box<
28 dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
29>;
30
31#[derive(Clone, Copy)]
32pub struct Receipt<T> {
33 sender_id: ConnectionId,
34 message_id: u32,
35 payload_type: PhantomData<T>,
36}
37
38pub struct TypedEnvelope<T> {
39 pub sender_id: ConnectionId,
40 original_sender_id: Option<PeerId>,
41 pub message_id: u32,
42 pub payload: T,
43}
44
45impl<T> TypedEnvelope<T> {
46 pub fn original_sender_id(&self) -> Result<PeerId> {
47 self.original_sender_id
48 .ok_or_else(|| anyhow!("missing original_sender_id"))
49 }
50}
51
52impl<T: RequestMessage> TypedEnvelope<T> {
53 pub fn receipt(&self) -> Receipt<T> {
54 Receipt {
55 sender_id: self.sender_id,
56 message_id: self.message_id,
57 payload_type: PhantomData,
58 }
59 }
60}
61
62pub struct Peer {
63 connections: RwLock<HashMap<ConnectionId, Connection>>,
64 message_handlers: RwLock<Vec<MessageHandler>>,
65 handler_types: Mutex<HashSet<TypeId>>,
66 next_connection_id: AtomicU32,
67}
68
69#[derive(Clone)]
70struct Connection {
71 outgoing_tx: mpsc::Sender<proto::Envelope>,
72 next_message_id: Arc<AtomicU32>,
73 response_channels: ResponseChannels,
74}
75
76pub struct ConnectionHandler<Conn> {
77 peer: Arc<Peer>,
78 connection_id: ConnectionId,
79 response_channels: ResponseChannels,
80 outgoing_rx: mpsc::Receiver<proto::Envelope>,
81 reader: MessageStream<Conn>,
82 writer: MessageStream<Conn>,
83}
84
85type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
86
87impl Peer {
88 pub fn new() -> Arc<Self> {
89 Arc::new(Self {
90 connections: Default::default(),
91 message_handlers: Default::default(),
92 handler_types: Default::default(),
93 next_connection_id: Default::default(),
94 })
95 }
96
97 pub async fn add_message_handler<T: EnvelopedMessage>(
98 &self,
99 ) -> mpsc::Receiver<TypedEnvelope<T>> {
100 if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
101 panic!("duplicate handler type");
102 }
103
104 let (tx, rx) = mpsc::channel(256);
105 self.message_handlers
106 .write()
107 .await
108 .push(Box::new(move |envelope, connection_id| {
109 if envelope.as_ref().map_or(false, T::matches_envelope) {
110 let envelope = Option::take(envelope).unwrap();
111 let mut tx = tx.clone();
112 Some(
113 async move {
114 tx.send(TypedEnvelope {
115 sender_id: connection_id,
116 original_sender_id: envelope.original_sender_id.map(PeerId),
117 message_id: envelope.id,
118 payload: T::from_envelope(envelope).unwrap(),
119 })
120 .await
121 .is_err()
122 }
123 .boxed(),
124 )
125 } else {
126 None
127 }
128 }));
129 rx
130 }
131
132 pub async fn add_connection<Conn>(
133 self: &Arc<Self>,
134 conn: Conn,
135 ) -> (ConnectionId, ConnectionHandler<Conn>)
136 where
137 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
138 {
139 let connection_id = ConnectionId(
140 self.next_connection_id
141 .fetch_add(1, atomic::Ordering::SeqCst),
142 );
143 let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
144 let connection = Connection {
145 outgoing_tx,
146 next_message_id: Default::default(),
147 response_channels: Default::default(),
148 };
149 let handler = ConnectionHandler {
150 peer: self.clone(),
151 connection_id,
152 response_channels: connection.response_channels.clone(),
153 outgoing_rx,
154 reader: MessageStream::new(conn.clone()),
155 writer: MessageStream::new(conn),
156 };
157 self.connections
158 .write()
159 .await
160 .insert(connection_id, connection);
161 (connection_id, handler)
162 }
163
164 pub async fn disconnect(&self, connection_id: ConnectionId) {
165 self.connections.write().await.remove(&connection_id);
166 }
167
168 pub async fn reset(&self) {
169 self.connections.write().await.clear();
170 self.handler_types.lock().await.clear();
171 self.message_handlers.write().await.clear();
172 }
173
174 pub fn request<T: RequestMessage>(
175 self: &Arc<Self>,
176 receiver_id: ConnectionId,
177 request: T,
178 ) -> impl Future<Output = Result<T::Response>> {
179 self.request_internal(None, receiver_id, request)
180 }
181
182 pub fn forward_request<T: RequestMessage>(
183 self: &Arc<Self>,
184 sender_id: ConnectionId,
185 receiver_id: ConnectionId,
186 request: T,
187 ) -> impl Future<Output = Result<T::Response>> {
188 self.request_internal(Some(sender_id), receiver_id, request)
189 }
190
191 pub fn request_internal<T: RequestMessage>(
192 self: &Arc<Self>,
193 original_sender_id: Option<ConnectionId>,
194 receiver_id: ConnectionId,
195 request: T,
196 ) -> impl Future<Output = Result<T::Response>> {
197 let this = self.clone();
198 let (tx, mut rx) = mpsc::channel(1);
199 async move {
200 let mut connection = this.connection(receiver_id).await?;
201 let message_id = connection
202 .next_message_id
203 .fetch_add(1, atomic::Ordering::SeqCst);
204 connection
205 .response_channels
206 .lock()
207 .await
208 .insert(message_id, tx);
209 connection
210 .outgoing_tx
211 .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
212 .await?;
213 let response = rx
214 .recv()
215 .await
216 .ok_or_else(|| anyhow!("connection was closed"))?;
217 T::Response::from_envelope(response)
218 .ok_or_else(|| anyhow!("received response of the wrong type"))
219 }
220 }
221
222 pub fn send<T: EnvelopedMessage>(
223 self: &Arc<Self>,
224 receiver_id: ConnectionId,
225 message: T,
226 ) -> impl Future<Output = Result<()>> {
227 let this = self.clone();
228 async move {
229 let mut connection = this.connection(receiver_id).await?;
230 let message_id = connection
231 .next_message_id
232 .fetch_add(1, atomic::Ordering::SeqCst);
233 connection
234 .outgoing_tx
235 .send(message.into_envelope(message_id, None, None))
236 .await?;
237 Ok(())
238 }
239 }
240
241 pub fn forward_send<T: EnvelopedMessage>(
242 self: &Arc<Self>,
243 sender_id: ConnectionId,
244 receiver_id: ConnectionId,
245 message: T,
246 ) -> impl Future<Output = Result<()>> {
247 let this = self.clone();
248 async move {
249 let mut connection = this.connection(receiver_id).await?;
250 let message_id = connection
251 .next_message_id
252 .fetch_add(1, atomic::Ordering::SeqCst);
253 connection
254 .outgoing_tx
255 .send(message.into_envelope(message_id, None, Some(sender_id.0)))
256 .await?;
257 Ok(())
258 }
259 }
260
261 pub fn respond<T: RequestMessage>(
262 self: &Arc<Self>,
263 receipt: Receipt<T>,
264 response: T::Response,
265 ) -> impl Future<Output = Result<()>> {
266 let this = self.clone();
267 async move {
268 let mut connection = this.connection(receipt.sender_id).await?;
269 let message_id = connection
270 .next_message_id
271 .fetch_add(1, atomic::Ordering::SeqCst);
272 connection
273 .outgoing_tx
274 .send(response.into_envelope(message_id, Some(receipt.message_id), None))
275 .await?;
276 Ok(())
277 }
278 }
279
280 fn connection(
281 self: &Arc<Self>,
282 connection_id: ConnectionId,
283 ) -> impl Future<Output = Result<Connection>> {
284 let this = self.clone();
285 async move {
286 let connections = this.connections.read().await;
287 let connection = connections
288 .get(&connection_id)
289 .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
290 Ok(connection.clone())
291 }
292 }
293}
294
295impl<Conn> ConnectionHandler<Conn>
296where
297 Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
298{
299 pub async fn run(mut self) -> Result<()> {
300 loop {
301 let read_message = self.reader.read_message().fuse();
302 futures::pin_mut!(read_message);
303 loop {
304 futures::select! {
305 incoming = read_message => match incoming {
306 Ok(incoming) => {
307 Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
308 break;
309 }
310 Err(error) => {
311 self.response_channels.lock().await.clear();
312 Err(error).context("received invalid RPC message")?;
313 }
314 },
315 outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
316 Some(outgoing) => {
317 if let Err(result) = self.writer.write_message(&outgoing).await {
318 self.response_channels.lock().await.clear();
319 Err(result).context("failed to write RPC message")?;
320 }
321 }
322 None => return Ok(()),
323 }
324 }
325 }
326 }
327 }
328
329 pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
330 let envelope = self.reader.read_message().await?;
331 let original_sender_id = envelope.original_sender_id;
332 let message_id = envelope.id;
333 let payload =
334 M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
335 Ok(TypedEnvelope {
336 sender_id: self.connection_id,
337 original_sender_id: original_sender_id.map(PeerId),
338 message_id,
339 payload,
340 })
341 }
342
343 async fn handle_incoming_message(
344 message: proto::Envelope,
345 peer: &Arc<Peer>,
346 connection_id: ConnectionId,
347 response_channels: &ResponseChannels,
348 ) {
349 if let Some(responding_to) = message.responding_to {
350 let channel = response_channels.lock().await.remove(&responding_to);
351 if let Some(mut tx) = channel {
352 tx.send(message).await.ok();
353 } else {
354 log::warn!("received RPC response to unknown request {}", responding_to);
355 }
356 } else {
357 let mut envelope = Some(message);
358 let mut handler_index = None;
359 let mut handler_was_dropped = false;
360 for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
361 if let Some(future) = handler(&mut envelope, connection_id) {
362 handler_was_dropped = future.await;
363 handler_index = Some(i);
364 break;
365 }
366 }
367
368 if let Some(handler_index) = handler_index {
369 if handler_was_dropped {
370 drop(peer.message_handlers.write().await.remove(handler_index));
371 }
372 } else {
373 log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
374 }
375 }
376 }
377}
378
379impl fmt::Display for ConnectionId {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 self.0.fmt(f)
382 }
383}
384
385impl fmt::Display for PeerId {
386 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387 self.0.fmt(f)
388 }
389}
390
391#[cfg(test)]
392mod tests {
393 use super::*;
394 use smol::{
395 io::AsyncWriteExt,
396 net::unix::{UnixListener, UnixStream},
397 };
398 use std::io;
399 use tempdir::TempDir;
400
/// Two clients each issue two requests to one server over Unix sockets and
/// must receive the responses the server task sends for them.
#[test]
fn test_request_response() {
    smol::block_on(async move {
        // create socket
        let socket_dir_path = TempDir::new("test-request-response").unwrap();
        let socket_path = socket_dir_path.path().join("test.sock");
        let listener = UnixListener::bind(&socket_path).unwrap();

        // create 2 clients connected to 1 server
        let server = Peer::new();
        let client1 = Peer::new();
        let client2 = Peer::new();
        let (client1_conn_id, task1) = client1
            .add_connection(UnixStream::connect(&socket_path).await.unwrap())
            .await;
        let (client2_conn_id, task2) = client2
            .add_connection(UnixStream::connect(&socket_path).await.unwrap())
            .await;
        let (_, task3) = server
            .add_connection(listener.accept().await.unwrap().0)
            .await;
        let (_, task4) = server
            .add_connection(listener.accept().await.unwrap().0)
            .await;
        smol::spawn(task1.run()).detach();
        smol::spawn(task2.run()).detach();
        smol::spawn(task3.run()).detach();
        smol::spawn(task4.run()).detach();

        // define the expected requests and responses
        let request1 = proto::Auth {
            user_id: 1,
            access_token: "token-1".to_string(),
        };
        let response1 = proto::AuthResponse {
            credentials_valid: true,
        };
        let request2 = proto::Auth {
            user_id: 2,
            access_token: "token-2".to_string(),
        };
        let response2 = proto::AuthResponse {
            credentials_valid: false,
        };
        let request3 = proto::OpenBuffer {
            worktree_id: 1,
            path: "path/two".to_string(),
        };
        let response3 = proto::OpenBufferResponse {
            buffer: Some(proto::Buffer {
                id: 2,
                content: "path/two content".to_string(),
                history: vec![],
                selections: vec![],
            }),
        };
        let request4 = proto::OpenBuffer {
            worktree_id: 2,
            path: "path/one".to_string(),
        };
        let response4 = proto::OpenBufferResponse {
            buffer: Some(proto::Buffer {
                id: 1,
                content: "path/one content".to_string(),
                history: vec![],
                selections: vec![],
            }),
        };

        // on the server, respond to two requests for each client
        let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
        let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
        // `oneshot` is not imported by this module, so name it through the crate.
        let (mut server_done_tx, mut server_done_rx) = postage::oneshot::channel::<()>();
        smol::spawn({
            let request1 = request1.clone();
            let request2 = request2.clone();
            let request3 = request3.clone();
            let request4 = request4.clone();
            let response1 = response1.clone();
            let response2 = response2.clone();
            let response3 = response3.clone();
            let response4 = response4.clone();
            async move {
                let msg = auth_rx.recv().await.unwrap();
                assert_eq!(msg.payload, request1);
                server.respond(msg.receipt(), response1).await.unwrap();

                let msg = auth_rx.recv().await.unwrap();
                assert_eq!(msg.payload, request2);
                server.respond(msg.receipt(), response2).await.unwrap();

                let msg = open_buffer_rx.recv().await.unwrap();
                assert_eq!(msg.payload, request3);
                server.respond(msg.receipt(), response3).await.unwrap();

                let msg = open_buffer_rx.recv().await.unwrap();
                assert_eq!(msg.payload, request4);
                server.respond(msg.receipt(), response4).await.unwrap();

                server_done_tx.send(()).await.unwrap();
            }
        })
        .detach();

        assert_eq!(
            client1.request(client1_conn_id, request1).await.unwrap(),
            response1
        );
        assert_eq!(
            client2.request(client2_conn_id, request2).await.unwrap(),
            response2
        );
        assert_eq!(
            client2.request(client2_conn_id, request3).await.unwrap(),
            response3
        );
        assert_eq!(
            client1.request(client1_conn_id, request4).await.unwrap(),
            response4
        );

        // Each client drops its own connection.
        client1.disconnect(client1_conn_id).await;
        client2.disconnect(client2_conn_id).await;

        server_done_rx.recv().await.unwrap();
    });
}
540
/// After `disconnect`, the handler's `run` future finishes and the remote end
/// of the socket sees a broken pipe.
#[test]
fn test_disconnect() {
    smol::block_on(async move {
        let temp_dir = TempDir::new("drop-client").unwrap();
        let sock = temp_dir.path().join(".sock");
        let listener = UnixListener::bind(&sock).unwrap();
        let client_stream = UnixStream::connect(&sock).await.unwrap();
        let (mut server_stream, _) = listener.accept().await.unwrap();

        let client = Peer::new();
        let (connection_id, handler) = client.add_connection(client_stream).await;

        // Signalled once the handler's run loop has returned.
        let (mut run_finished_tx, mut run_finished_rx) = postage::barrier::channel();
        let run_to_completion = async move {
            handler.run().await.ok();
            run_finished_tx.send(()).await.unwrap();
        };
        smol::spawn(run_to_completion).detach();

        client.disconnect(connection_id).await;
        run_finished_rx.recv().await;

        let write_error = server_stream.write(&[]).await.unwrap_err();
        assert_eq!(write_error.kind(), io::ErrorKind::BrokenPipe);
    });
}
567
/// A request over a stream that was closed before use must fail with
/// "connection was closed".
#[test]
fn test_io_error() {
    smol::block_on(async move {
        let temp_dir = TempDir::new("io-error").unwrap();
        let sock = temp_dir.path().join(".sock");
        let _listener = UnixListener::bind(&sock).unwrap();

        // Close the client side before handing it to the peer.
        let mut closed_stream = UnixStream::connect(&sock).await.unwrap();
        closed_stream.close().await.unwrap();

        let client = Peer::new();
        let (connection_id, handler) = client.add_connection(closed_stream).await;
        smol::spawn(handler.run()).detach();

        let auth = proto::Auth {
            user_id: 42,
            access_token: "token".to_string(),
        };
        let outcome = client.request(connection_id, auth).await;
        let err = outcome.unwrap_err();
        assert_eq!(err.to_string(), "connection was closed");
    });
}
594}