peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    pin::Pin,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22};
 23
 24type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 25type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 26
 27#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 28pub struct ConnectionId(u32);
 29
 30struct Connection {
 31    writer: Mutex<MessageStream<BoxedWriter>>,
 32    reader: Mutex<MessageStream<BoxedReader>>,
 33    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 34    next_message_id: AtomicU32,
 35}
 36
 37type MessageHandler = Box<
 38    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 39>;
 40
 41pub struct TypedEnvelope<T> {
 42    pub id: u32,
 43    pub connection_id: ConnectionId,
 44    pub payload: T,
 45}
 46
 47pub struct Peer {
 48    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 49    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
 50    message_handlers: RwLock<Vec<MessageHandler>>,
 51    handler_types: Mutex<HashSet<TypeId>>,
 52    next_connection_id: AtomicU32,
 53}
 54
 55impl Peer {
 56    pub fn new() -> Arc<Self> {
 57        Arc::new(Self {
 58            connections: Default::default(),
 59            connection_close_barriers: Default::default(),
 60            message_handlers: Default::default(),
 61            handler_types: Default::default(),
 62            next_connection_id: Default::default(),
 63        })
 64    }
 65
 66    pub async fn add_message_handler<T: EnvelopedMessage>(
 67        &self,
 68    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 69        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 70            panic!("duplicate handler type");
 71        }
 72
 73        let (tx, rx) = mpsc::channel(256);
 74        self.message_handlers
 75            .write()
 76            .await
 77            .push(Box::new(move |envelope, connection_id| {
 78                if envelope.as_ref().map_or(false, T::matches_envelope) {
 79                    let envelope = Option::take(envelope).unwrap();
 80                    let mut tx = tx.clone();
 81                    Some(
 82                        async move {
 83                            tx.send(TypedEnvelope {
 84                                id: envelope.id,
 85                                connection_id,
 86                                payload: T::from_envelope(envelope).unwrap(),
 87                            })
 88                            .await
 89                            .is_err()
 90                        }
 91                        .boxed(),
 92                    )
 93                } else {
 94                    None
 95                }
 96            }));
 97        rx
 98    }
 99
100    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
101    where
102        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
103    {
104        let connection_id = ConnectionId(
105            self.next_connection_id
106                .fetch_add(1, atomic::Ordering::SeqCst),
107        );
108        self.connections.write().await.insert(
109            connection_id,
110            Arc::new(Connection {
111                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
112                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
113                response_channels: Default::default(),
114                next_message_id: Default::default(),
115            }),
116        );
117        connection_id
118    }
119
120    pub async fn disconnect(&self, connection_id: ConnectionId) {
121        self.connections.write().await.remove(&connection_id);
122        self.connection_close_barriers
123            .write()
124            .await
125            .remove(&connection_id);
126    }
127
128    pub fn handle_messages(
129        self: &Arc<Self>,
130        connection_id: ConnectionId,
131    ) -> impl Future<Output = Result<()>> + 'static {
132        let (close_tx, mut close_rx) = barrier::channel();
133        let this = self.clone();
134        async move {
135            this.connection_close_barriers
136                .write()
137                .await
138                .insert(connection_id, close_tx);
139            let connection = this.connection(connection_id).await?;
140            let closed = close_rx.recv();
141            futures::pin_mut!(closed);
142
143            loop {
144                let mut reader = connection.reader.lock().await;
145                let read_message = reader.read_message();
146                futures::pin_mut!(read_message);
147
148                match futures::future::select(read_message, &mut closed).await {
149                    Either::Left((Ok(incoming), _)) => {
150                        if let Some(responding_to) = incoming.responding_to {
151                            let channel = connection
152                                .response_channels
153                                .lock()
154                                .await
155                                .remove(&responding_to);
156                            if let Some(mut tx) = channel {
157                                tx.send(incoming).await.ok();
158                            } else {
159                                log::warn!(
160                                    "received RPC response to unknown request {}",
161                                    responding_to
162                                );
163                            }
164                        } else {
165                            let mut envelope = Some(incoming);
166                            let mut handler_index = None;
167                            let mut handler_was_dropped = false;
168                            for (i, handler) in
169                                this.message_handlers.read().await.iter().enumerate()
170                            {
171                                if let Some(future) = handler(&mut envelope, connection_id) {
172                                    handler_was_dropped = future.await;
173                                    handler_index = Some(i);
174                                    break;
175                                }
176                            }
177
178                            if let Some(handler_index) = handler_index {
179                                if handler_was_dropped {
180                                    drop(this.message_handlers.write().await.remove(handler_index));
181                                }
182                            } else {
183                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
184                            }
185                        }
186                    }
187                    Either::Left((Err(error), _)) => {
188                        log::warn!("received invalid RPC message: {}", error);
189                        Err(error)?;
190                    }
191                    Either::Right(_) => return Ok(()),
192                }
193            }
194        }
195    }
196
197    pub async fn receive<M: EnvelopedMessage>(
198        self: &Arc<Self>,
199        connection_id: ConnectionId,
200    ) -> Result<TypedEnvelope<M>> {
201        let connection = self.connection(connection_id).await?;
202        let envelope = connection.reader.lock().await.read_message().await?;
203        let id = envelope.id;
204        let payload =
205            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
206        Ok(TypedEnvelope {
207            id,
208            connection_id,
209            payload,
210        })
211    }
212
213    pub fn request<T: RequestMessage>(
214        self: &Arc<Self>,
215        connection_id: ConnectionId,
216        req: T,
217    ) -> impl Future<Output = Result<T::Response>> {
218        let this = self.clone();
219        let (tx, mut rx) = oneshot::channel();
220        async move {
221            let connection = this.connection(connection_id).await?;
222            let message_id = connection
223                .next_message_id
224                .fetch_add(1, atomic::Ordering::SeqCst);
225            connection
226                .response_channels
227                .lock()
228                .await
229                .insert(message_id, tx);
230            connection
231                .writer
232                .lock()
233                .await
234                .write_message(&req.into_envelope(message_id, None))
235                .await?;
236            let response = rx
237                .recv()
238                .await
239                .expect("response channel was unexpectedly dropped");
240            T::Response::from_envelope(response)
241                .ok_or_else(|| anyhow!("received response of the wrong type"))
242        }
243    }
244
245    pub fn send<T: EnvelopedMessage>(
246        self: &Arc<Self>,
247        connection_id: ConnectionId,
248        message: T,
249    ) -> impl Future<Output = Result<()>> {
250        let this = self.clone();
251        async move {
252            let connection = this.connection(connection_id).await?;
253            let message_id = connection
254                .next_message_id
255                .fetch_add(1, atomic::Ordering::SeqCst);
256            connection
257                .writer
258                .lock()
259                .await
260                .write_message(&message.into_envelope(message_id, None))
261                .await?;
262            Ok(())
263        }
264    }
265
266    pub fn respond<T: RequestMessage>(
267        self: &Arc<Self>,
268        request: TypedEnvelope<T>,
269        response: T::Response,
270    ) -> impl Future<Output = Result<()>> {
271        let this = self.clone();
272        async move {
273            let connection = this.connection(request.connection_id).await?;
274            let message_id = connection
275                .next_message_id
276                .fetch_add(1, atomic::Ordering::SeqCst);
277            connection
278                .writer
279                .lock()
280                .await
281                .write_message(&response.into_envelope(message_id, Some(request.id)))
282                .await?;
283            Ok(())
284        }
285    }
286
287    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
288        Ok(self
289            .connections
290            .read()
291            .await
292            .get(&id)
293            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
294            .clone())
295    }
296}
297
298impl fmt::Display for ConnectionId {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        self.0.fmt(f)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: "path/two".to_string(),
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: "path/one".to_string(),
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The task gets its own copies; the originals are used by the clients below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}