peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    fmt,
 16    future::Future,
 17    pin::Pin,
 18    sync::{
 19        atomic::{self, AtomicU32},
 20        Arc,
 21    },
 22};
 23
 24type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 25type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 26
 27#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 28pub struct ConnectionId(u32);
 29
/// Per-connection state: message streams over two clones of the same transport, plus the
/// bookkeeping needed to route RPC responses back to the requests that caused them.
struct Connection {
    /// Outgoing message stream; locked for the duration of each write.
    writer: Mutex<MessageStream<BoxedWriter>>,
    /// Incoming message stream; locked by whichever task is currently reading.
    reader: Mutex<MessageStream<BoxedReader>>,
    /// Senders waiting for a response, keyed by the message id of the request they sent.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Counter used to assign ids to outgoing messages on this connection.
    next_message_id: AtomicU32,
}
 36
 37type MessageHandler = Box<
 38    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 39>;
 40
 41pub struct TypedEnvelope<T> {
 42    pub id: u32,
 43    pub connection_id: ConnectionId,
 44    pub payload: T,
 45}
 46
/// One endpoint of the RPC protocol, managing any number of connections.
pub struct Peer {
    /// Live connections, keyed by the id assigned in `add_connection`.
    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
    /// Senders whose removal (in `disconnect`) ends the matching `handle_messages` loop.
    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
    /// Handlers for non-response messages, tried in registration order.
    message_handlers: RwLock<Vec<MessageHandler>>,
    /// Message types that have had a handler registered; used to reject duplicates.
    handler_types: Mutex<HashSet<TypeId>>,
    /// Counter used to assign connection ids.
    next_connection_id: AtomicU32,
}
 54
 55impl Peer {
 56    pub fn new() -> Arc<Self> {
 57        Arc::new(Self {
 58            connections: Default::default(),
 59            connection_close_barriers: Default::default(),
 60            message_handlers: Default::default(),
 61            handler_types: Default::default(),
 62            next_connection_id: Default::default(),
 63        })
 64    }
 65
 66    pub async fn add_message_handler<T: EnvelopedMessage>(
 67        &self,
 68    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 69        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 70            panic!("duplicate handler type");
 71        }
 72
 73        let (tx, rx) = mpsc::channel(256);
 74        self.message_handlers
 75            .write()
 76            .await
 77            .push(Box::new(move |envelope, connection_id| {
 78                if envelope.as_ref().map_or(false, T::matches_envelope) {
 79                    let envelope = Option::take(envelope).unwrap();
 80                    let mut tx = tx.clone();
 81                    Some(
 82                        async move {
 83                            tx.send(TypedEnvelope {
 84                                id: envelope.id,
 85                                connection_id,
 86                                payload: T::from_envelope(envelope).unwrap(),
 87                            })
 88                            .await
 89                            .is_err()
 90                        }
 91                        .boxed(),
 92                    )
 93                } else {
 94                    None
 95                }
 96            }));
 97        rx
 98    }
 99
100    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
101    where
102        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
103    {
104        let connection_id = ConnectionId(
105            self.next_connection_id
106                .fetch_add(1, atomic::Ordering::SeqCst),
107        );
108        self.connections.write().await.insert(
109            connection_id,
110            Arc::new(Connection {
111                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
112                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
113                response_channels: Default::default(),
114                next_message_id: Default::default(),
115            }),
116        );
117        connection_id
118    }
119
120    pub async fn disconnect(&self, connection_id: ConnectionId) {
121        self.connections.write().await.remove(&connection_id);
122        self.connection_close_barriers
123            .write()
124            .await
125            .remove(&connection_id);
126    }
127
128    pub fn handle_messages(
129        self: &Arc<Self>,
130        connection_id: ConnectionId,
131    ) -> impl Future<Output = Result<()>> + 'static {
132        let (close_tx, mut close_rx) = barrier::channel();
133        let this = self.clone();
134        async move {
135            this.connection_close_barriers
136                .write()
137                .await
138                .insert(connection_id, close_tx);
139            let connection = this.connection(connection_id).await?;
140            let closed = close_rx.recv();
141            futures::pin_mut!(closed);
142
143            loop {
144                let mut reader = connection.reader.lock().await;
145                let read_message = reader.read_message();
146                futures::pin_mut!(read_message);
147
148                match futures::future::select(read_message, &mut closed).await {
149                    Either::Left((Ok(incoming), _)) => {
150                        if let Some(responding_to) = incoming.responding_to {
151                            let channel = connection
152                                .response_channels
153                                .lock()
154                                .await
155                                .remove(&responding_to);
156                            if let Some(mut tx) = channel {
157                                tx.send(incoming).await.ok();
158                            } else {
159                                log::warn!(
160                                    "received RPC response to unknown request {}",
161                                    responding_to
162                                );
163                            }
164                        } else {
165                            let mut envelope = Some(incoming);
166                            let mut handler_index = None;
167                            let mut handler_was_dropped = false;
168                            for (i, handler) in
169                                this.message_handlers.read().await.iter().enumerate()
170                            {
171                                if let Some(future) = handler(&mut envelope, connection_id) {
172                                    handler_was_dropped = future.await;
173                                    handler_index = Some(i);
174                                    break;
175                                }
176                            }
177
178                            if let Some(handler_index) = handler_index {
179                                if handler_was_dropped {
180                                    drop(this.message_handlers.write().await.remove(handler_index));
181                                }
182                            } else {
183                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
184                            }
185                        }
186                    }
187                    Either::Left((Err(error), _)) => {
188                        log::warn!("received invalid RPC message: {}", error);
189                        Err(error)?;
190                    }
191                    Either::Right(_) => return Ok(()),
192                }
193            }
194        }
195    }
196
197    pub async fn receive<M: EnvelopedMessage>(
198        self: &Arc<Self>,
199        connection_id: ConnectionId,
200    ) -> Result<TypedEnvelope<M>> {
201        let connection = self.connection(connection_id).await?;
202        let envelope = connection.reader.lock().await.read_message().await?;
203        let id = envelope.id;
204        let payload =
205            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
206        Ok(TypedEnvelope {
207            id,
208            connection_id,
209            payload,
210        })
211    }
212
213    pub fn request<T: RequestMessage>(
214        self: &Arc<Self>,
215        connection_id: ConnectionId,
216        req: T,
217    ) -> impl Future<Output = Result<T::Response>> {
218        let this = self.clone();
219        let (tx, mut rx) = oneshot::channel();
220        async move {
221            let connection = this.connection(connection_id).await?;
222            let message_id = connection
223                .next_message_id
224                .fetch_add(1, atomic::Ordering::SeqCst);
225            connection
226                .response_channels
227                .lock()
228                .await
229                .insert(message_id, tx);
230            connection
231                .writer
232                .lock()
233                .await
234                .write_message(&req.into_envelope(message_id, None))
235                .await?;
236            let response = rx
237                .recv()
238                .await
239                .expect("response channel was unexpectedly dropped");
240            T::Response::from_envelope(response)
241                .ok_or_else(|| anyhow!("received response of the wrong type"))
242        }
243    }
244
245    pub fn send<T: EnvelopedMessage>(
246        self: &Arc<Self>,
247        connection_id: ConnectionId,
248        message: T,
249    ) -> impl Future<Output = Result<()>> {
250        let this = self.clone();
251        async move {
252            let connection = this.connection(connection_id).await?;
253            let message_id = connection
254                .next_message_id
255                .fetch_add(1, atomic::Ordering::SeqCst);
256            connection
257                .writer
258                .lock()
259                .await
260                .write_message(&message.into_envelope(message_id, None))
261                .await?;
262            Ok(())
263        }
264    }
265
266    pub fn respond<T: RequestMessage>(
267        self: &Arc<Self>,
268        request: TypedEnvelope<T>,
269        response: T::Response,
270    ) -> impl Future<Output = Result<()>> {
271        let this = self.clone();
272        async move {
273            let connection = this.connection(request.connection_id).await?;
274            let message_id = connection
275                .next_message_id
276                .fetch_add(1, atomic::Ordering::SeqCst);
277            connection
278                .writer
279                .lock()
280                .await
281                .write_message(&response.into_envelope(message_id, Some(request.id)))
282                .await?;
283            Ok(())
284        }
285    }
286
287    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
288        Ok(self
289            .connections
290            .read()
291            .await
292            .get(&id)
293            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
294            .clone())
295    }
296}
297
298impl fmt::Display for ConnectionId {
299    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
300        self.0.fmt(f)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec!["path/one".to_string()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec!["path/two".to_string(), "path/three".to_string()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: "path/two".to_string(),
                    content: "path/two content".to_string(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: "path/one".to_string(),
                    content: "path/one content".to_string(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The spawned task owns its own copies; the originals are used below.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}