peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{
  5    future::{BoxFuture, Either},
  6    AsyncRead, AsyncWrite, FutureExt,
  7};
  8use postage::{
  9    barrier, mpsc, oneshot,
 10    prelude::{Sink, Stream},
 11};
 12use std::{
 13    any::TypeId,
 14    collections::{HashMap, HashSet},
 15    future::Future,
 16    pin::Pin,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21};
 22
 23type BoxedWriter = Pin<Box<dyn AsyncWrite + 'static + Send>>;
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 26pub struct ConnectionId(u32);
 27
 28struct Connection {
 29    writer: Mutex<MessageStream<BoxedWriter>>,
 30    response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
 31    next_message_id: AtomicU32,
 32}
 33
 34type MessageHandler = Box<
 35    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 36>;
 37
 38pub struct TypedEnvelope<T> {
 39    id: u32,
 40    connection_id: ConnectionId,
 41    payload: T,
 42}
 43
 44impl<T> TypedEnvelope<T> {
 45    pub fn connection_id(&self) -> ConnectionId {
 46        self.connection_id
 47    }
 48
 49    pub fn payload(&self) -> &T {
 50        &self.payload
 51    }
 52}
 53
 54pub struct Peer {
 55    connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
 56    message_handlers: RwLock<Vec<MessageHandler>>,
 57    handler_types: Mutex<HashSet<TypeId>>,
 58    next_connection_id: AtomicU32,
 59}
 60
 61impl Peer {
 62    pub fn new() -> Arc<Self> {
 63        Arc::new(Self {
 64            connections: Default::default(),
 65            message_handlers: Default::default(),
 66            handler_types: Default::default(),
 67            next_connection_id: Default::default(),
 68        })
 69    }
 70
 71    pub async fn add_message_handler<T: EnvelopedMessage>(
 72        &self,
 73    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 74        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
 75            panic!("duplicate handler type");
 76        }
 77
 78        let (tx, rx) = mpsc::channel(256);
 79        self.message_handlers
 80            .write()
 81            .await
 82            .push(Box::new(move |envelope, connection_id| {
 83                if envelope.as_ref().map_or(false, T::matches_envelope) {
 84                    let envelope = Option::take(envelope).unwrap();
 85                    let mut tx = tx.clone();
 86                    Some(
 87                        async move {
 88                            tx.send(TypedEnvelope {
 89                                id: envelope.id,
 90                                connection_id,
 91                                payload: T::from_envelope(envelope).unwrap(),
 92                            })
 93                            .await
 94                            .is_err()
 95                        }
 96                        .boxed(),
 97                    )
 98                } else {
 99                    None
100                }
101            }));
102        rx
103    }
104
105    pub async fn add_connection<Conn>(
106        self: &Arc<Self>,
107        conn: Conn,
108    ) -> (ConnectionId, impl Future<Output = Result<()>>)
109    where
110        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
111    {
112        let connection_id = ConnectionId(
113            self.next_connection_id
114                .fetch_add(1, atomic::Ordering::SeqCst),
115        );
116        let (close_tx, mut close_rx) = barrier::channel();
117        let connection = Arc::new(Connection {
118            writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
119            response_channels: Default::default(),
120            next_message_id: Default::default(),
121        });
122
123        self.connections
124            .write()
125            .await
126            .insert(connection_id, (connection.clone(), close_tx));
127
128        let this = self.clone();
129        let handler_future = async move {
130            let closed = close_rx.recv();
131            futures::pin_mut!(closed);
132
133            let mut stream = MessageStream::new(conn);
134            loop {
135                let read_message = stream.read_message();
136                futures::pin_mut!(read_message);
137
138                match futures::future::select(read_message, &mut closed).await {
139                    Either::Left((Ok(incoming), _)) => {
140                        if let Some(responding_to) = incoming.responding_to {
141                            let channel = connection
142                                .response_channels
143                                .lock()
144                                .await
145                                .remove(&responding_to);
146                            if let Some(mut tx) = channel {
147                                tx.send(incoming).await.ok();
148                            } else {
149                                log::warn!(
150                                    "received RPC response to unknown request {}",
151                                    responding_to
152                                );
153                            }
154                        } else {
155                            let mut envelope = Some(incoming);
156                            let mut handler_index = None;
157                            let mut handler_was_dropped = false;
158                            for (i, handler) in
159                                this.message_handlers.read().await.iter().enumerate()
160                            {
161                                if let Some(future) = handler(&mut envelope, connection_id) {
162                                    handler_was_dropped = future.await;
163                                    handler_index = Some(i);
164                                    break;
165                                }
166                            }
167
168                            if let Some(handler_index) = handler_index {
169                                if handler_was_dropped {
170                                    drop(this.message_handlers.write().await.remove(handler_index));
171                                }
172                            } else {
173                                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
174                            }
175                        }
176                    }
177                    Either::Left((Err(error), _)) => {
178                        log::warn!("received invalid RPC message: {}", error);
179                        Err(error)?;
180                    }
181                    Either::Right(_) => return Ok(()),
182                }
183            }
184        };
185
186        (connection_id, handler_future)
187    }
188
189    pub async fn disconnect(&self, connection_id: ConnectionId) {
190        self.connections.write().await.remove(&connection_id);
191    }
192
193    pub fn request<T: RequestMessage>(
194        self: &Arc<Self>,
195        connection_id: ConnectionId,
196        req: T,
197    ) -> impl Future<Output = Result<T::Response>> {
198        let this = self.clone();
199        let (tx, mut rx) = oneshot::channel();
200        async move {
201            let connection = this.connection(connection_id).await?;
202            let message_id = connection
203                .next_message_id
204                .fetch_add(1, atomic::Ordering::SeqCst);
205            connection
206                .response_channels
207                .lock()
208                .await
209                .insert(message_id, tx);
210            connection
211                .writer
212                .lock()
213                .await
214                .write_message(&req.into_envelope(message_id, None))
215                .await?;
216            let response = rx
217                .recv()
218                .await
219                .expect("response channel was unexpectedly dropped");
220            T::Response::from_envelope(response)
221                .ok_or_else(|| anyhow!("received response of the wrong type"))
222        }
223    }
224
225    pub fn send<T: EnvelopedMessage>(
226        self: &Arc<Self>,
227        connection_id: ConnectionId,
228        message: T,
229    ) -> impl Future<Output = Result<()>> {
230        let this = self.clone();
231        async move {
232            let connection = this.connection(connection_id).await?;
233            let message_id = connection
234                .next_message_id
235                .fetch_add(1, atomic::Ordering::SeqCst);
236            connection
237                .writer
238                .lock()
239                .await
240                .write_message(&message.into_envelope(message_id, None))
241                .await?;
242            Ok(())
243        }
244    }
245
246    pub fn respond<T: RequestMessage>(
247        self: &Arc<Self>,
248        request: TypedEnvelope<T>,
249        response: T::Response,
250    ) -> impl Future<Output = Result<()>> {
251        let this = self.clone();
252        async move {
253            let connection = this.connection(request.connection_id).await?;
254            let message_id = connection
255                .next_message_id
256                .fetch_add(1, atomic::Ordering::SeqCst);
257            connection
258                .writer
259                .lock()
260                .await
261                .write_message(&response.into_envelope(message_id, Some(request.id)))
262                .await?;
263            Ok(())
264        }
265    }
266
267    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
268        Ok(self
269            .connections
270            .read()
271            .await
272            .get(&id)
273            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
274            .0
275            .clone())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let (client1_conn_id, f1) = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (client2_conn_id, f2) = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (_, f3) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let (_, f4) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(f1).detach();
            smol::spawn(f2).detach();
            smol::spawn(f3).detach();
            smol::spawn(f4).detach();

            // define the expected requests and responses
            let request1 = proto::OpenWorktree {
                worktree_id: 101,
                access_token: "first-worktree-access-token".to_string(),
            };
            let response1 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/one".to_vec()],
                }),
            };
            let request2 = proto::OpenWorktree {
                worktree_id: 102,
                access_token: "second-worktree-access-token".to_string(),
            };
            let response2 = proto::OpenWorktreeResponse {
                worktree: Some(proto::Worktree {
                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
                }),
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 102,
                path: b"path/two".to_vec(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1001,
                    path: b"path/two".to_vec(),
                    content: b"path/two content".to_vec(),
                    history: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 101,
                path: b"path/one".to_vec(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1002,
                    path: b"path/one".to_vec(),
                    content: b"path/one content".to_vec(),
                    history: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
            smol::spawn({
                // The originals are still needed below for the client side.
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg, response1).await.unwrap();

                    let msg = open_worktree_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg, response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg, response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg, response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            // Each client disconnects using the id its own peer issued.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler).detach();
            client.disconnect(connection_id).await;

            // Try sending an empty payload over and over, until the client is dropped and hangs up.
            // NOTE(review): whether a zero-length write ever reports BrokenPipe
            // after the peer closes is platform dependent — TODO confirm this
            // terminates on every target the tests run on.
            loop {
                match server_conn.write(&[]).await {
                    Ok(_) => {}
                    Err(err) => {
                        if err.kind() == io::ErrorKind::BrokenPipe {
                            break;
                        }
                    }
                }
            }
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(
                err.downcast_ref::<io::Error>().unwrap().kind(),
                io::ErrorKind::BrokenPipe
            );
        });
    }
}