peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    pin::Pin,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21};
 22
 23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 24type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 25
 26#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 27pub struct ConnectionId(u32);
 28
/// State for one connection: its stream (split into a read side and a write side, each
/// behind its own lock) plus the bookkeeping needed to match responses to requests.
struct Connection {
    /// Stream used for outgoing messages; locked for the duration of each write.
    writer: Mutex<MessageStream<BoxedWriter>>,
    /// Stream used for incoming messages; held by whoever is currently reading
    /// (`Peer::handle_messages` or `Peer::receive`).
    reader: Mutex<MessageStream<BoxedReader>>,
    /// Senders for requests awaiting a reply, keyed by the message id the request was sent
    /// with. An incoming message whose `responding_to` equals a key is delivered through
    /// the corresponding sender.
    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
    /// Counter that supplies the id of every outgoing message on this connection.
    next_message_id: AtomicU32,
}
 35
 36type MessageHandler = Box<
 37    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 38>;
 39
/// A decoded message of type `T` together with where it came from.
pub struct TypedEnvelope<T> {
    /// Id the sender assigned to this message on its connection; `Peer::respond` passes it
    /// back with the response so the sender can match the reply to its request.
    pub id: u32,
    /// Connection the message arrived on.
    pub connection_id: ConnectionId,
    /// The decoded message body.
    pub payload: T,
}
 45
/// A message-passing endpoint managing any number of connections.
///
/// It routes incoming messages either to the request awaiting them (responses) or to
/// registered per-type handlers (everything else).
pub struct Peer {
    /// Open connections, keyed by the id returned from `add_connection`.
    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
    /// One close barrier per `handle_messages` loop. `disconnect` removes (and thereby drops)
    /// the sender, which ends that loop's wait on the barrier.
    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
    /// Handlers for unsolicited messages, tried in registration order.
    message_handlers: RwLock<Vec<MessageHandler>>,
    /// Message types that already have a handler; used to reject duplicate registrations.
    handler_types: Mutex<HashSet<TypeId>>,
    /// Counter that supplies the id of each new connection.
    next_connection_id: AtomicU32,
}
 53
 54impl Peer {
 55    pub fn new() -> Arc<Self> {
 56        Arc::new(Self {
 57            connections: Default::default(),
 58            connection_close_barriers: Default::default(),
 59            message_handlers: Default::default(),
 60            handler_types: Default::default(),
 61            next_connection_id: Default::default(),
 62        })
 63    }
 64
 65    pub async fn add_message_handler<T: EnvelopedMessage>(
 66        &self,
 67    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 68        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 69            panic!("duplicate handler type");
 70        }
 71
 72        let (tx, rx) = mpsc::channel(256);
 73        self.message_handlers
 74            .write()
 75            .await
 76            .push(Box::new(move |envelope, connection_id| {
 77                if envelope.as_ref().map_or(false, T::matches_envelope) {
 78                    let envelope = Option::take(envelope).unwrap();
 79                    let mut tx = tx.clone();
 80                    Some(
 81                        async move {
 82                            tx.send(TypedEnvelope {
 83                                id: envelope.id,
 84                                connection_id,
 85                                payload: T::from_envelope(envelope).unwrap(),
 86                            })
 87                            .await
 88                            .is_err()
 89                        }
 90                        .boxed(),
 91                    )
 92                } else {
 93                    None
 94                }
 95            }));
 96        rx
 97    }
 98
 99    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
100    where
101        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
102    {
103        let connection_id = ConnectionId(
104            self.next_connection_id
105                .fetch_add(1, atomic::Ordering::SeqCst),
106        );
107        self.connections.write().await.insert(
108            connection_id,
109            Arc::new(Connection {
110                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
111                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
112                response_channels: Default::default(),
113                next_message_id: Default::default(),
114            }),
115        );
116        connection_id
117    }
118
119    pub async fn disconnect(&self, connection_id: ConnectionId) {
120        self.connections.write().await.remove(&connection_id);
121        self.connection_close_barriers
122            .write()
123            .await
124            .remove(&connection_id);
125    }
126
127    pub fn handle_messages(
128        self: &Arc<Self>,
129        connection_id: ConnectionId,
130    ) -> impl Future<Output = Result<()>> + 'static {
131        let (close_tx, mut close_rx) = barrier::channel();
132        let this = self.clone();
133        async move {
134            this.connection_close_barriers
135                .write()
136                .await
137                .insert(connection_id, close_tx);
138            let connection = this.connection(connection_id).await?;
139            let closed = close_rx.recv();
140            futures::pin_mut!(closed);
141
142            loop {
143                let mut reader = connection.reader.lock().await;
144                let read_message = reader.read_message();
145                futures::pin_mut!(read_message);
146
147                match futures::future::select(read_message, &mut closed).await {
148                    Either::Left((Ok(incoming), _)) => {
149                        if let Some(responding_to) = incoming.responding_to {
150                            let channel = connection
151                                .response_channels
152                                .lock()
153                                .await
154                                .remove(&responding_to);
155                            if let Some(mut tx) = channel {
156                                tx.send(incoming).await.ok();
157                            } else {
158                                log::warn!(
159                                    "received RPC response to unknown request {}",
160                                    responding_to
161                                );
162                            }
163                        } else {
164                            let mut envelope = Some(incoming);
165                            let mut handler_index = None;
166                            let mut handler_was_dropped = false;
167                            for (i, handler) in
168                                this.message_handlers.read().await.iter().enumerate()
169                            {
170                                if let Some(future) = handler(&mut envelope, connection_id) {
171                                    handler_was_dropped = future.await;
172                                    handler_index = Some(i);
173                                    break;
174                                }
175                            }
176
177                            if let Some(handler_index) = handler_index {
178                                if handler_was_dropped {
179                                    drop(this.message_handlers.write().await.remove(handler_index));
180                                }
181                            } else {
182                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
183                            }
184                        }
185                    }
186                    Either::Left((Err(error), _)) => {
187                        log::warn!("received invalid RPC message: {}", error);
188                        Err(error)?;
189                    }
190                    Either::Right(_) => return Ok(()),
191                }
192            }
193        }
194    }
195
196    pub async fn receive<M: EnvelopedMessage>(
197        self: &Arc<Self>,
198        connection_id: ConnectionId,
199    ) -> Result<TypedEnvelope<M>> {
200        let connection = self.connection(connection_id).await?;
201        let envelope = connection.reader.lock().await.read_message().await?;
202        let id = envelope.id;
203        let payload =
204            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
205        Ok(TypedEnvelope {
206            id,
207            connection_id,
208            payload,
209        })
210    }
211
212    pub fn request<T: RequestMessage>(
213        self: &Arc<Self>,
214        connection_id: ConnectionId,
215        req: T,
216    ) -> impl Future<Output = Result<T::Response>> {
217        let this = self.clone();
218        let (tx, mut rx) = oneshot::channel();
219        async move {
220            let connection = this.connection(connection_id).await?;
221            let message_id = connection
222                .next_message_id
223                .fetch_add(1, atomic::Ordering::SeqCst);
224            connection
225                .response_channels
226                .lock()
227                .await
228                .insert(message_id, tx);
229            connection
230                .writer
231                .lock()
232                .await
233                .write_message(&req.into_envelope(message_id, None))
234                .await?;
235            let response = rx
236                .recv()
237                .await
238                .expect("response channel was unexpectedly dropped");
239            T::Response::from_envelope(response)
240                .ok_or_else(|| anyhow!("received response of the wrong type"))
241        }
242    }
243
244    pub fn send<T: EnvelopedMessage>(
245        self: &Arc<Self>,
246        connection_id: ConnectionId,
247        message: T,
248    ) -> impl Future<Output = Result<()>> {
249        let this = self.clone();
250        async move {
251            let connection = this.connection(connection_id).await?;
252            let message_id = connection
253                .next_message_id
254                .fetch_add(1, atomic::Ordering::SeqCst);
255            connection
256                .writer
257                .lock()
258                .await
259                .write_message(&message.into_envelope(message_id, None))
260                .await?;
261            Ok(())
262        }
263    }
264
265    pub fn respond<T: RequestMessage>(
266        self: &Arc<Self>,
267        request: TypedEnvelope<T>,
268        response: T::Response,
269    ) -> impl Future<Output = Result<()>> {
270        let this = self.clone();
271        async move {
272            let connection = this.connection(request.connection_id).await?;
273            let message_id = connection
274                .next_message_id
275                .fetch_add(1, atomic::Ordering::SeqCst);
276            connection
277                .writer
278                .lock()
279                .await
280                .write_message(&response.into_envelope(message_id, Some(request.id)))
281                .await?;
282            Ok(())
283        }
284    }
285
286    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
287        Ok(self
288            .connections
289            .read()
290            .await
291            .get(&id)
292            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
293            .clone())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients each issue two requests to one server over Unix sockets. The server
    /// answers them through its typed message handlers, and each client gets back the
    /// response matching its own request.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let client1_conn_id = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let client2_conn_id = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let server_conn_id1 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let server_conn_id2 = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
            smol::spawn(server.handle_messages(server_conn_id1)).detach();
            smol::spawn(server.handle_messages(server_conn_id2)).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/one".to_vec()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: b"path/two".to_vec(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: b"path/two".to_vec(),
                    content: b"path/two content".to_vec(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: b"path/one".to_vec(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: b"path/one".to_vec(),
                    content: b"path/one content".to_vec(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The clients consume the originals below, so the server task gets copies.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects its own connection id. Ids are numbered per peer, so the
            // two ids happen to be equal here, but they belong to different peers.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting a connection ends its `handle_messages` loop and releases the client's
    /// stream, so writes from the other end fail with a broken pipe.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                barrier::channel();
            let handle_messages = client.handle_messages(connection_id);
            smol::spawn(async move {
                handle_messages.await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on a connection whose stream is already closed fails with the underlying
    /// I/O error rather than hanging or panicking.
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let connection_id = client.add_connection(client_conn).await;
            smol::spawn(client.handle_messages(connection_id)).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}