peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
  5use postage::{
  6    mpsc,
  7    prelude::{Sink, Stream},
  8};
  9use std::{
 10    any::TypeId,
 11    collections::{HashMap, HashSet},
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
 21#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 22pub struct ConnectionId(pub u32);
 23
 24#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 25pub struct PeerId(pub u32);
 26
 27type MessageHandler = Box<
 28    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
 29>;
 30
 31#[derive(Clone, Copy)]
 32pub struct Receipt<T> {
 33    sender_id: ConnectionId,
 34    message_id: u32,
 35    payload_type: PhantomData<T>,
 36}
 37
 38pub struct TypedEnvelope<T> {
 39    pub sender_id: ConnectionId,
 40    original_sender_id: Option<PeerId>,
 41    pub message_id: u32,
 42    pub payload: T,
 43}
 44
 45impl<T> TypedEnvelope<T> {
 46    pub fn original_sender_id(&self) -> Result<PeerId> {
 47        self.original_sender_id
 48            .ok_or_else(|| anyhow!("missing original_sender_id"))
 49    }
 50}
 51
 52impl<T: RequestMessage> TypedEnvelope<T> {
 53    pub fn receipt(&self) -> Receipt<T> {
 54        Receipt {
 55            sender_id: self.sender_id,
 56            message_id: self.message_id,
 57            payload_type: PhantomData,
 58        }
 59    }
 60}
 61
 62pub struct Peer {
 63    connections: RwLock<HashMap<ConnectionId, Connection>>,
 64    message_handlers: RwLock<Vec<MessageHandler>>,
 65    handler_types: Mutex<HashSet<TypeId>>,
 66    next_connection_id: AtomicU32,
 67}
 68
 69#[derive(Clone)]
 70struct Connection {
 71    outgoing_tx: mpsc::Sender<proto::Envelope>,
 72    next_message_id: Arc<AtomicU32>,
 73    response_channels: ResponseChannels,
 74}
 75
 76pub struct ConnectionHandler<Conn> {
 77    peer: Arc<Peer>,
 78    connection_id: ConnectionId,
 79    response_channels: ResponseChannels,
 80    outgoing_rx: mpsc::Receiver<proto::Envelope>,
 81    reader: MessageStream<Conn>,
 82    writer: MessageStream<Conn>,
 83}
 84
 85type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
 86
 87impl Peer {
 88    pub fn new() -> Arc<Self> {
 89        Arc::new(Self {
 90            connections: Default::default(),
 91            message_handlers: Default::default(),
 92            handler_types: Default::default(),
 93            next_connection_id: Default::default(),
 94        })
 95    }
 96
 97    pub async fn add_message_handler<T: EnvelopedMessage>(
 98        &self,
 99    ) -> mpsc::Receiver<TypedEnvelope<T>> {
100        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
101            panic!("duplicate handler type");
102        }
103
104        let (tx, rx) = mpsc::channel(256);
105        self.message_handlers
106            .write()
107            .await
108            .push(Box::new(move |envelope, connection_id| {
109                if envelope.as_ref().map_or(false, T::matches_envelope) {
110                    let envelope = Option::take(envelope).unwrap();
111                    let mut tx = tx.clone();
112                    Some(
113                        async move {
114                            tx.send(TypedEnvelope {
115                                sender_id: connection_id,
116                                original_sender_id: envelope.original_sender_id.map(PeerId),
117                                message_id: envelope.id,
118                                payload: T::from_envelope(envelope).unwrap(),
119                            })
120                            .await
121                            .is_err()
122                        }
123                        .boxed(),
124                    )
125                } else {
126                    None
127                }
128            }));
129        rx
130    }
131
132    pub async fn add_connection<Conn>(
133        self: &Arc<Self>,
134        conn: Conn,
135    ) -> (ConnectionId, ConnectionHandler<Conn>)
136    where
137        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
138    {
139        let connection_id = ConnectionId(
140            self.next_connection_id
141                .fetch_add(1, atomic::Ordering::SeqCst),
142        );
143        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
144        let connection = Connection {
145            outgoing_tx,
146            next_message_id: Default::default(),
147            response_channels: Default::default(),
148        };
149        let handler = ConnectionHandler {
150            peer: self.clone(),
151            connection_id,
152            response_channels: connection.response_channels.clone(),
153            outgoing_rx,
154            reader: MessageStream::new(conn.clone()),
155            writer: MessageStream::new(conn),
156        };
157        self.connections
158            .write()
159            .await
160            .insert(connection_id, connection);
161        (connection_id, handler)
162    }
163
164    pub async fn disconnect(&self, connection_id: ConnectionId) {
165        self.connections.write().await.remove(&connection_id);
166    }
167
168    pub async fn reset(&self) {
169        self.connections.write().await.clear();
170        self.handler_types.lock().await.clear();
171        self.message_handlers.write().await.clear();
172    }
173
174    pub fn request<T: RequestMessage>(
175        self: &Arc<Self>,
176        receiver_id: ConnectionId,
177        request: T,
178    ) -> impl Future<Output = Result<T::Response>> {
179        self.request_internal(None, receiver_id, request)
180    }
181
182    pub fn forward_request<T: RequestMessage>(
183        self: &Arc<Self>,
184        sender_id: ConnectionId,
185        receiver_id: ConnectionId,
186        request: T,
187    ) -> impl Future<Output = Result<T::Response>> {
188        self.request_internal(Some(sender_id), receiver_id, request)
189    }
190
191    pub fn request_internal<T: RequestMessage>(
192        self: &Arc<Self>,
193        original_sender_id: Option<ConnectionId>,
194        receiver_id: ConnectionId,
195        request: T,
196    ) -> impl Future<Output = Result<T::Response>> {
197        let this = self.clone();
198        let (tx, mut rx) = mpsc::channel(1);
199        async move {
200            let mut connection = this.connection(receiver_id).await?;
201            let message_id = connection
202                .next_message_id
203                .fetch_add(1, atomic::Ordering::SeqCst);
204            connection
205                .response_channels
206                .lock()
207                .await
208                .insert(message_id, tx);
209            connection
210                .outgoing_tx
211                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
212                .await?;
213            let response = rx
214                .recv()
215                .await
216                .ok_or_else(|| anyhow!("connection was closed"))?;
217            T::Response::from_envelope(response)
218                .ok_or_else(|| anyhow!("received response of the wrong type"))
219        }
220    }
221
222    pub fn send<T: EnvelopedMessage>(
223        self: &Arc<Self>,
224        receiver_id: ConnectionId,
225        message: T,
226    ) -> impl Future<Output = Result<()>> {
227        let this = self.clone();
228        async move {
229            let mut connection = this.connection(receiver_id).await?;
230            let message_id = connection
231                .next_message_id
232                .fetch_add(1, atomic::Ordering::SeqCst);
233            connection
234                .outgoing_tx
235                .send(message.into_envelope(message_id, None, None))
236                .await?;
237            Ok(())
238        }
239    }
240
241    pub fn forward_send<T: EnvelopedMessage>(
242        self: &Arc<Self>,
243        sender_id: ConnectionId,
244        receiver_id: ConnectionId,
245        message: T,
246    ) -> impl Future<Output = Result<()>> {
247        let this = self.clone();
248        async move {
249            let mut connection = this.connection(receiver_id).await?;
250            let message_id = connection
251                .next_message_id
252                .fetch_add(1, atomic::Ordering::SeqCst);
253            connection
254                .outgoing_tx
255                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
256                .await?;
257            Ok(())
258        }
259    }
260
261    pub fn respond<T: RequestMessage>(
262        self: &Arc<Self>,
263        receipt: Receipt<T>,
264        response: T::Response,
265    ) -> impl Future<Output = Result<()>> {
266        let this = self.clone();
267        async move {
268            let mut connection = this.connection(receipt.sender_id).await?;
269            let message_id = connection
270                .next_message_id
271                .fetch_add(1, atomic::Ordering::SeqCst);
272            connection
273                .outgoing_tx
274                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
275                .await?;
276            Ok(())
277        }
278    }
279
280    fn connection(
281        self: &Arc<Self>,
282        connection_id: ConnectionId,
283    ) -> impl Future<Output = Result<Connection>> {
284        let this = self.clone();
285        async move {
286            let connections = this.connections.read().await;
287            let connection = connections
288                .get(&connection_id)
289                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
290            Ok(connection.clone())
291        }
292    }
293}
294
295impl<Conn> ConnectionHandler<Conn>
296where
297    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
298{
299    pub async fn run(mut self) -> Result<()> {
300        loop {
301            let read_message = self.reader.read_message().fuse();
302            futures::pin_mut!(read_message);
303            loop {
304                futures::select! {
305                    incoming = read_message => match incoming {
306                        Ok(incoming) => {
307                            Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
308                            break;
309                        }
310                        Err(error) => {
311                            self.response_channels.lock().await.clear();
312                            Err(error).context("received invalid RPC message")?;
313                        }
314                    },
315                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
316                        Some(outgoing) => {
317                            if let Err(result) = self.writer.write_message(&outgoing).await {
318                                self.response_channels.lock().await.clear();
319                                Err(result).context("failed to write RPC message")?;
320                            }
321                        }
322                        None => return Ok(()),
323                    }
324                }
325            }
326        }
327    }
328
329    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
330        let envelope = self.reader.read_message().await?;
331        let original_sender_id = envelope.original_sender_id;
332        let message_id = envelope.id;
333        let payload =
334            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
335        Ok(TypedEnvelope {
336            sender_id: self.connection_id,
337            original_sender_id: original_sender_id.map(PeerId),
338            message_id,
339            payload,
340        })
341    }
342
343    async fn handle_incoming_message(
344        message: proto::Envelope,
345        peer: &Arc<Peer>,
346        connection_id: ConnectionId,
347        response_channels: &ResponseChannels,
348    ) {
349        if let Some(responding_to) = message.responding_to {
350            let channel = response_channels.lock().await.remove(&responding_to);
351            if let Some(mut tx) = channel {
352                tx.send(message).await.ok();
353            } else {
354                log::warn!("received RPC response to unknown request {}", responding_to);
355            }
356        } else {
357            let mut envelope = Some(message);
358            let mut handler_index = None;
359            let mut handler_was_dropped = false;
360            for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
361                if let Some(future) = handler(&mut envelope, connection_id) {
362                    handler_was_dropped = future.await;
363                    handler_index = Some(i);
364                    break;
365                }
366            }
367
368            if let Some(handler_index) = handler_index {
369                if handler_was_dropped {
370                    drop(peer.message_handlers.write().await.remove(handler_index));
371                }
372            } else {
373                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
374            }
375        }
376    }
377}
378
379impl fmt::Display for ConnectionId {
380    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381        self.0.fmt(f)
382    }
383}
384
385impl fmt::Display for PeerId {
386    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387        self.0.fmt(f)
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use smol::{
        io::AsyncWriteExt,
        net::unix::{UnixListener, UnixStream},
    };
    use std::io;
    use tempdir::TempDir;

    /// Two clients connect to one server. Each request sent by a client must come back
    /// with the response the server produced for it.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create socket
            let socket_dir_path = TempDir::new("test-request-response").unwrap();
            let socket_path = socket_dir_path.path().join("test.sock");
            let listener = UnixListener::bind(&socket_path).unwrap();

            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();
            let (client1_conn_id, task1) = client1
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (client2_conn_id, task2) = client2
                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
                .await;
            let (_, task3) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            let (_, task4) = server
                .add_connection(listener.accept().await.unwrap().0)
                .await;
            smol::spawn(task1.run()).detach();
            smol::spawn(task2.run()).detach();
            smol::spawn(task3.run()).detach();
            smol::spawn(task4.run()).detach();

            // define the expected requests and responses
            let request1 = proto::Auth {
                user_id: 1,
                access_token: "token-1".to_string(),
            };
            let response1 = proto::AuthResponse {
                credentials_valid: true,
            };
            let request2 = proto::Auth {
                user_id: 2,
                access_token: "token-2".to_string(),
            };
            let response2 = proto::AuthResponse {
                credentials_valid: false,
            };
            let request3 = proto::OpenBuffer {
                worktree_id: 1,
                path: "path/two".to_string(),
            };
            let response3 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 2,
                    content: "path/two content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };
            let request4 = proto::OpenBuffer {
                worktree_id: 2,
                path: "path/one".to_string(),
            };
            let response4 = proto::OpenBufferResponse {
                buffer: Some(proto::Buffer {
                    id: 1,
                    content: "path/one content".to_string(),
                    history: vec![],
                    selections: vec![],
                }),
            };

            // on the server, respond to two requests for each client
            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
            // `oneshot` is not among the parent module's imports, so name it by full path
            // (like `postage::barrier` below).
            let (mut server_done_tx, mut server_done_rx) = postage::oneshot::channel::<()>();
            smol::spawn({
                let request1 = request1.clone();
                let request2 = request2.clone();
                let request3 = request3.clone();
                let request4 = request4.clone();
                let response1 = response1.clone();
                let response2 = response2.clone();
                let response3 = response3.clone();
                let response4 = response4.clone();
                async move {
                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request1);
                    server.respond(msg.receipt(), response1).await.unwrap();

                    let msg = auth_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request2);
                    server.respond(msg.receipt(), response2).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request3);
                    server.respond(msg.receipt(), response3).await.unwrap();

                    let msg = open_buffer_rx.recv().await.unwrap();
                    assert_eq!(msg.payload, request4);
                    server.respond(msg.receipt(), response4).await.unwrap();

                    server_done_tx.send(()).await.unwrap();
                }
            })
            .detach();

            assert_eq!(
                client1.request(client1_conn_id, request1).await.unwrap(),
                response1
            );
            assert_eq!(
                client2.request(client2_conn_id, request2).await.unwrap(),
                response2
            );
            assert_eq!(
                client2.request(client2_conn_id, request3).await.unwrap(),
                response3
            );
            assert_eq!(
                client1.request(client1_conn_id, request4).await.unwrap(),
                response4
            );

            client1.disconnect(client1_conn_id).await;
            // Each client disconnects its own connection.
            client2.disconnect(client2_conn_id).await;

            server_done_rx.recv().await.unwrap();
        });
    }

    /// Disconnecting a connection ends its handler and closes the underlying socket.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("drop-client").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let listener = UnixListener::bind(&socket_path).unwrap();
            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
            let (mut server_conn, _) = listener.accept().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
                postage::barrier::channel();
            smol::spawn(async move {
                handler.run().await.ok();
                incoming_messages_ended_tx.send(()).await.unwrap();
            })
            .detach();
            client.disconnect(connection_id).await;

            incoming_messages_ended_rx.recv().await;

            let err = server_conn.write(&[]).await.unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
        });
    }

    /// A request on a connection whose writes fail resolves to "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let socket_dir_path = TempDir::new("io-error").unwrap();
            let socket_path = socket_dir_path.path().join(".sock");
            let _listener = UnixListener::bind(&socket_path).unwrap();
            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
            client_conn.close().await.unwrap();

            let client = Peer::new();
            let (connection_id, handler) = client.add_connection(client_conn).await;
            smol::spawn(handler.run()).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}