peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    pin::Pin,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21};
 22
 23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 24type BoxedReader = Pin<Box<dyn AsyncRead + 'static + Send>>;
 25
 26#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 27pub struct ConnectionId(u32);
 28
 29struct Connection {
 30    writer: Mutex<MessageStream<BoxedWriter>>,
 31    reader: Mutex<MessageStream<BoxedReader>>,
 32    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 33    next_message_id: AtomicU32,
 34}
 35
 36type MessageHandler = Box<
 37    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 38>;
 39
 40pub struct TypedEnvelope<T> {
 41    id: u32,
 42    connection_id: ConnectionId,
 43    payload: T,
 44}
 45
 46impl<T> TypedEnvelope<T> {
 47    pub fn connection_id(&self) -> ConnectionId {
 48        self.connection_id
 49    }
 50
 51    pub fn payload(&self) -> &T {
 52        &self.payload
 53    }
 54}
 55
 56pub struct Peer {
 57    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
 58    connection_close_barriers: RwLock<HashMap<ConnectionId, barrier::Sender>>,
 59    message_handlers: RwLock<Vec<MessageHandler>>,
 60    handler_types: Mutex<HashSet<TypeId>>,
 61    next_connection_id: AtomicU32,
 62}
 63
 64impl Peer {
 65    pub fn new() -> Arc<Self> {
 66        Arc::new(Self {
 67            connections: Default::default(),
 68            connection_close_barriers: Default::default(),
 69            message_handlers: Default::default(),
 70            handler_types: Default::default(),
 71            next_connection_id: Default::default(),
 72        })
 73    }
 74
 75    pub async fn add_message_handler<T: EnvelopedMessage>(
 76        &self,
 77    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 78        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 79            panic!("duplicate handler type");
 80        }
 81
 82        let (tx, rx) = mpsc::channel(256);
 83        self.message_handlers
 84            .write()
 85            .await
 86            .push(Box::new(move |envelope, connection_id| {
 87                if envelope.as_ref().map_or(false, T::matches_envelope) {
 88                    let envelope = Option::take(envelope).unwrap();
 89                    let mut tx = tx.clone();
 90                    Some(
 91                        async move {
 92                            tx.send(TypedEnvelope {
 93                                id: envelope.id,
 94                                connection_id,
 95                                payload: T::from_envelope(envelope).unwrap(),
 96                            })
 97                            .await
 98                            .is_err()
 99                        }
100                        .boxed(),
101                    )
102                } else {
103                    None
104                }
105            }));
106        rx
107    }
108
109    pub async fn add_connection<Conn>(self: &Arc<Self>, conn: Conn) -> ConnectionId
110    where
111        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
112    {
113        let connection_id = ConnectionId(
114            self.next_connection_id
115                .fetch_add(1, atomic::Ordering::SeqCst),
116        );
117        self.connections.write().await.insert(
118            connection_id,
119            Arc::new(Connection {
120                reader: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
121                writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
122                response_channels: Default::default(),
123                next_message_id: Default::default(),
124            }),
125        );
126        connection_id
127    }
128
129    pub async fn disconnect(&self, connection_id: ConnectionId) {
130        self.connections.write().await.remove(&connection_id);
131        self.connection_close_barriers
132            .write()
133            .await
134            .remove(&connection_id);
135    }
136
137    pub fn handle_messages(
138        self: &Arc<Self>,
139        connection_id: ConnectionId,
140    ) -> impl Future<Output = Result<()>> + 'static {
141        let (close_tx, mut close_rx) = barrier::channel();
142        let this = self.clone();
143        async move {
144            this.connection_close_barriers
145                .write()
146                .await
147                .insert(connection_id, close_tx);
148            let connection = this.connection(connection_id).await?;
149            let closed = close_rx.recv();
150            futures::pin_mut!(closed);
151
152            loop {
153                let mut reader = connection.reader.lock().await;
154                let read_message = reader.read_message();
155                futures::pin_mut!(read_message);
156
157                match futures::future::select(read_message, &mut closed).await {
158                    Either::Left((Ok(incoming), _)) => {
159                        if let Some(responding_to) = incoming.responding_to {
160                            let channel = connection
161                                .response_channels
162                                .lock()
163                                .await
164                                .remove(&responding_to);
165                            if let Some(mut tx) = channel {
166                                tx.send(incoming).await.ok();
167                            } else {
168                                log::warn!(
169                                    "received RPC response to unknown request {}",
170                                    responding_to
171                                );
172                            }
173                        } else {
174                            let mut envelope = Some(incoming);
175                            let mut handler_index = None;
176                            let mut handler_was_dropped = false;
177                            for (i, handler) in
178                                this.message_handlers.read().await.iter().enumerate()
179                            {
180                                if let Some(future) = handler(&mut envelope, connection_id) {
181                                    handler_was_dropped = future.await;
182                                    handler_index = Some(i);
183                                    break;
184                                }
185                            }
186
187                            if let Some(handler_index) = handler_index {
188                                if handler_was_dropped {
189                                    drop(this.message_handlers.write().await.remove(handler_index));
190                                }
191                            } else {
192                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
193                            }
194                        }
195                    }
196                    Either::Left((Err(error), _)) => {
197                        log::warn!("received invalid RPC message: {}", error);
198                        Err(error)?;
199                    }
200                    Either::Right(_) => return Ok(()),
201                }
202            }
203        }
204    }
205
206    pub async fn receive<M: EnvelopedMessage>(
207        self: &Arc<Self>,
208        connection_id: ConnectionId,
209    ) -> Result<TypedEnvelope<M>> {
210        let connection = self.connection(connection_id).await?;
211        let envelope = connection.reader.lock().await.read_message().await?;
212        let id = envelope.id;
213        let payload =
214            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
215        Ok(TypedEnvelope {
216            id,
217            connection_id,
218            payload,
219        })
220    }
221
222    pub fn request<T: RequestMessage>(
223        self: &Arc<Self>,
224        connection_id: ConnectionId,
225        req: T,
226    ) -> impl Future<Output = Result<T::Response>> {
227        let this = self.clone();
228        let (tx, mut rx) = oneshot::channel();
229        async move {
230            let connection = this.connection(connection_id).await?;
231            let message_id = connection
232                .next_message_id
233                .fetch_add(1, atomic::Ordering::SeqCst);
234            connection
235                .response_channels
236                .lock()
237                .await
238                .insert(message_id, tx);
239            connection
240                .writer
241                .lock()
242                .await
243                .write_message(&req.into_envelope(message_id, None))
244                .await?;
245            let response = rx
246                .recv()
247                .await
248                .expect("response channel was unexpectedly dropped");
249            T::Response::from_envelope(response)
250                .ok_or_else(|| anyhow!("received response of the wrong type"))
251        }
252    }
253
254    pub fn send<T: EnvelopedMessage>(
255        self: &Arc<Self>,
256        connection_id: ConnectionId,
257        message: T,
258    ) -> impl Future<Output = Result<()>> {
259        let this = self.clone();
260        async move {
261            let connection = this.connection(connection_id).await?;
262            let message_id = connection
263                .next_message_id
264                .fetch_add(1, atomic::Ordering::SeqCst);
265            connection
266                .writer
267                .lock()
268                .await
269                .write_message(&message.into_envelope(message_id, None))
270                .await?;
271            Ok(())
272        }
273    }
274
275    pub fn respond<T: RequestMessage>(
276        self: &Arc<Self>,
277        request: TypedEnvelope<T>,
278        response: T::Response,
279    ) -> impl Future<Output = Result<()>> {
280        let this = self.clone();
281        async move {
282            let connection = this.connection(request.connection_id).await?;
283            let message_id = connection
284                .next_message_id
285                .fetch_add(1, atomic::Ordering::SeqCst);
286            connection
287                .writer
288                .lock()
289                .await
290                .write_message(&response.into_envelope(message_id, Some(request.id)))
291                .await?;
292            Ok(())
293        }
294    }
295
296    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
297        Ok(self
298            .connections
299            .read()
300            .await
301            .get(&id)
302            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
303            .clone())
304    }
305}
306
307#[cfg(test)]
308mod tests {
309    use super::*;
310    use smol::{
311        io::AsyncWriteExt,
312        net::unix::{UnixListener, UnixStream},
313    };
314    use std::io;
315    use tempdir::TempDir;
316
317    #[test]
318    fn test_request_response() {
319        smol::block_on(async move {
320            // create socket
321            let socket_dir_path = TempDir::new("test-request-response").unwrap();
322            let socket_path = socket_dir_path.path().join("test.sock");
323            let listener = UnixListener::bind(&socket_path).unwrap();
324
325            // create 2 clients connected to 1 server
326            let server = Peer::new();
327            let client1 = Peer::new();
328            let client2 = Peer::new();
329            let client1_conn_id = client1
330                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
331                .await;
332            let client2_conn_id = client2
333                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
334                .await;
335            let server_conn_id1 = server
336                .add_connection(listener.accept().await.unwrap().0)
337                .await;
338            let server_conn_id2 = server
339                .add_connection(listener.accept().await.unwrap().0)
340                .await;
341            smol::spawn(client1.handle_messages(client1_conn_id)).detach();
342            smol::spawn(client2.handle_messages(client2_conn_id)).detach();
343            smol::spawn(server.handle_messages(server_conn_id1)).detach();
344            smol::spawn(server.handle_messages(server_conn_id2)).detach();
345
346            // define the expected requests and responses
347            let request1 = proto::OpenWorktree {
348                worktree_id: 101,
349                access_token: "first-worktree-access-token".to_string(),
350            };
351            let response1 = proto::OpenWorktreeResponse {
352                worktree: Some(proto::Worktree {
353                    paths: vec![b"path/one".to_vec()],
354                }),
355            };
356            let request2 = proto::OpenWorktree {
357                worktree_id: 102,
358                access_token: "second-worktree-access-token".to_string(),
359            };
360            let response2 = proto::OpenWorktreeResponse {
361                worktree: Some(proto::Worktree {
362                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
363                }),
364            };
365            let request3 = proto::OpenBuffer {
366                worktree_id: 102,
367                path: b"path/two".to_vec(),
368            };
369            let response3 = proto::OpenBufferResponse {
370                buffer: Some(proto::Buffer {
371                    id: 1001,
372                    path: b"path/two".to_vec(),
373                    content: b"path/two content".to_vec(),
374                    history: vec![],
375                }),
376            };
377            let request4 = proto::OpenBuffer {
378                worktree_id: 101,
379                path: b"path/one".to_vec(),
380            };
381            let response4 = proto::OpenBufferResponse {
382                buffer: Some(proto::Buffer {
383                    id: 1002,
384                    path: b"path/one".to_vec(),
385                    content: b"path/one content".to_vec(),
386                    history: vec![],
387                }),
388            };
389
390            // on the server, respond to two requests for each client
391            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
392            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
393            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
394            smol::spawn({
395                let request1 = request1.clone();
396                let request2 = request2.clone();
397                let request3 = request3.clone();
398                let request4 = request4.clone();
399                let response1 = response1.clone();
400                let response2 = response2.clone();
401                let response3 = response3.clone();
402                let response4 = response4.clone();
403                async move {
404                    let msg = open_worktree_rx.recv().await.unwrap();
405                    assert_eq!(msg.payload, request1);
406                    server.respond(msg, response1.clone()).await.unwrap();
407
408                    let msg = open_worktree_rx.recv().await.unwrap();
409                    assert_eq!(msg.payload, request2.clone());
410                    server.respond(msg, response2.clone()).await.unwrap();
411
412                    let msg = open_buffer_rx.recv().await.unwrap();
413                    assert_eq!(msg.payload, request3.clone());
414                    server.respond(msg, response3.clone()).await.unwrap();
415
416                    let msg = open_buffer_rx.recv().await.unwrap();
417                    assert_eq!(msg.payload, request4.clone());
418                    server.respond(msg, response4.clone()).await.unwrap();
419
420                    server_done_tx.send(()).await.unwrap();
421                }
422            })
423            .detach();
424
425            assert_eq!(
426                client1.request(client1_conn_id, request1).await.unwrap(),
427                response1
428            );
429            assert_eq!(
430                client2.request(client2_conn_id, request2).await.unwrap(),
431                response2
432            );
433            assert_eq!(
434                client2.request(client2_conn_id, request3).await.unwrap(),
435                response3
436            );
437            assert_eq!(
438                client1.request(client1_conn_id, request4).await.unwrap(),
439                response4
440            );
441
442            client1.disconnect(client1_conn_id).await;
443            client2.disconnect(client1_conn_id).await;
444
445            server_done_rx.recv().await.unwrap();
446        });
447    }
448
449    #[test]
450    fn test_disconnect() {
451        smol::block_on(async move {
452            let socket_dir_path = TempDir::new("drop-client").unwrap();
453            let socket_path = socket_dir_path.path().join(".sock");
454            let listener = UnixListener::bind(&socket_path).unwrap();
455            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
456            let (mut server_conn, _) = listener.accept().await.unwrap();
457
458            let client = Peer::new();
459            let connection_id = client.add_connection(client_conn).await;
460            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
461                barrier::channel();
462            let handle_messages = client.handle_messages(connection_id);
463            smol::spawn(async move {
464                handle_messages.await.ok();
465                incoming_messages_ended_tx.send(()).await.unwrap();
466            })
467            .detach();
468            client.disconnect(connection_id).await;
469
470            incoming_messages_ended_rx.recv().await;
471
472            let err = server_conn.write(&[]).await.unwrap_err();
473            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
474        });
475    }
476
477    #[test]
478    fn test_io_error() {
479        smol::block_on(async move {
480            let socket_dir_path = TempDir::new("io-error").unwrap();
481            let socket_path = socket_dir_path.path().join(".sock");
482            let _listener = UnixListener::bind(&socket_path).unwrap();
483            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
484            client_conn.close().await.unwrap();
485
486            let client = Peer::new();
487            let connection_id = client.add_connection(client_conn).await;
488            smol::spawn(client.handle_messages(connection_id)).detach();
489
490            let err = client
491                .request(
492                    connection_id,
493                    proto::Auth {
494                        user_id: 42,
495                        access_token: "token".to_string(),
496                    },
497                )
498                .await
499                .unwrap_err();
500            assert_eq!(
501                err.downcast_ref::<io::Error>().unwrap().kind(),
502                io::ErrorKind::BrokenPipe
503            );
504        });
505    }
506}