peer.rs

  1use crate::proto::{self, EnvelopedMessage, MessageStream, RequestMessage};
  2use anyhow::{anyhow, Context, Result};
  3use async_lock::{Mutex, RwLock};
  4use futures::{future::BoxFuture, AsyncRead, AsyncWrite, FutureExt};
  5use postage::{
  6    mpsc,
  7    prelude::{Sink, Stream},
  8};
  9use std::{
 10    any::TypeId,
 11    collections::{HashMap, HashSet},
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
/// Identifies a connection registered with a [`Peer`].
///
/// Ids are allocated by `Peer::add_connection` from a per-peer counter, so a value is only
/// meaningful relative to the peer that issued it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct ConnectionId(pub u32);
 23
/// Identifies the original sender of a message.
///
/// Within this file it is only ever built from an envelope's `original_sender_id` field.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct PeerId(pub u32);
 26
/// A type-erased handler for incoming (non-response) envelopes.
///
/// The handler is offered the envelope through a `&mut Option`. A handler that accepts the
/// message takes the envelope out of the option and returns a future; a handler that does not
/// recognize the message returns `None` and leaves the option untouched. The future resolves to
/// `true` when delivery failed because the receiving side of the handler's channel is gone.
type MessageHandler = Box<
    dyn Send + Sync + Fn(&mut Option<proto::Envelope>, ConnectionId) -> Option<BoxFuture<bool>>,
>;
 30
/// A token recording which connection sent a request of type `T` and under which message id.
///
/// Obtained from `TypedEnvelope::receipt` and consumed by `Peer::respond` to address the reply.
pub struct Receipt<T> {
    // Connection the request arrived on; the response is sent back over it.
    sender_id: ConnectionId,
    // Id of the request message; echoed back as the response's `responding_to`.
    message_id: u32,
    // Ties the receipt to the request type without storing a value of it.
    payload_type: PhantomData<T>,
}
 36
/// An incoming message whose payload has already been decoded into its concrete type `T`.
pub struct TypedEnvelope<T> {
    /// Connection on which the message was received.
    pub sender_id: ConnectionId,
    // Original sender recorded in the envelope, if any; read via `original_sender_id()`.
    original_sender_id: Option<PeerId>,
    /// Id carried by the envelope.
    pub message_id: u32,
    /// The decoded message body.
    pub payload: T,
}
 43
 44impl<T> TypedEnvelope<T> {
 45    pub fn original_sender_id(&self) -> Result<PeerId> {
 46        self.original_sender_id
 47            .ok_or_else(|| anyhow!("missing original_sender_id"))
 48    }
 49}
 50
 51impl<T: RequestMessage> TypedEnvelope<T> {
 52    pub fn receipt(&self) -> Receipt<T> {
 53        Receipt {
 54            sender_id: self.sender_id,
 55            message_id: self.message_id,
 56            payload_type: PhantomData,
 57        }
 58    }
 59}
 60
/// One endpoint of the RPC system: owns a set of connections and the handlers that incoming
/// messages are dispatched to.
pub struct Peer {
    // Live connections, keyed by the id handed out in `add_connection`.
    connections: RwLock<HashMap<ConnectionId, Connection>>,
    // Handlers tried in registration order for every incoming non-response message.
    message_handlers: RwLock<Vec<MessageHandler>>,
    // Message types that already have a handler; used to reject duplicates.
    handler_types: Mutex<HashSet<TypeId>>,
    // Source of fresh `ConnectionId`s.
    next_connection_id: AtomicU32,
}
 67
/// The peer-side handle to a single connection. Cloning yields another handle to the same
/// outgoing queue, id counter and response-channel table.
#[derive(Clone)]
struct Connection {
    // Queue of envelopes that the connection's `ConnectionHandler` writes to the wire.
    outgoing_tx: mpsc::Sender<proto::Envelope>,
    // Counter supplying ids for messages sent over this connection.
    next_message_id: Arc<AtomicU32>,
    // Senders for in-flight requests, keyed by request id; shared with the handler.
    response_channels: ResponseChannels,
}
 74
/// The I/O half of a connection, returned by `Peer::add_connection`.
///
/// Call `run` to pump messages in both directions, or `receive` to read a single typed message.
pub struct ConnectionHandler<Conn> {
    // Peer whose message handlers incoming messages are dispatched to.
    peer: Arc<Peer>,
    // Id this connection was registered under.
    connection_id: ConnectionId,
    // Same table as the matching `Connection::response_channels`.
    response_channels: ResponseChannels,
    // Receiving end of the matching `Connection::outgoing_tx`.
    outgoing_rx: mpsc::Receiver<proto::Envelope>,
    // Reader and writer are separate streams built over clones of the same `Conn`.
    reader: MessageStream<Conn>,
    writer: MessageStream<Conn>,
}
 83
/// Shared table of in-flight requests: maps a request's message id to the sender on which its
/// response envelope should be delivered.
type ResponseChannels = Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>;
 85
 86impl Peer {
 87    pub fn new() -> Arc<Self> {
 88        Arc::new(Self {
 89            connections: Default::default(),
 90            message_handlers: Default::default(),
 91            handler_types: Default::default(),
 92            next_connection_id: Default::default(),
 93        })
 94    }
 95
 96    pub async fn add_message_handler<T: EnvelopedMessage>(
 97        &self,
 98    ) -> mpsc::Receiver<TypedEnvelope<T>> {
 99        if !self.handler_types.lock().await.insert(TypeId::of::<T>()) {
100            panic!("duplicate handler type");
101        }
102
103        let (tx, rx) = mpsc::channel(256);
104        self.message_handlers
105            .write()
106            .await
107            .push(Box::new(move |envelope, connection_id| {
108                if envelope.as_ref().map_or(false, T::matches_envelope) {
109                    let envelope = Option::take(envelope).unwrap();
110                    let mut tx = tx.clone();
111                    Some(
112                        async move {
113                            tx.send(TypedEnvelope {
114                                sender_id: connection_id,
115                                original_sender_id: envelope.original_sender_id.map(PeerId),
116                                message_id: envelope.id,
117                                payload: T::from_envelope(envelope).unwrap(),
118                            })
119                            .await
120                            .is_err()
121                        }
122                        .boxed(),
123                    )
124                } else {
125                    None
126                }
127            }));
128        rx
129    }
130
131    pub async fn add_connection<Conn>(
132        self: &Arc<Self>,
133        conn: Conn,
134    ) -> (ConnectionId, ConnectionHandler<Conn>)
135    where
136        Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
137    {
138        let connection_id = ConnectionId(
139            self.next_connection_id
140                .fetch_add(1, atomic::Ordering::SeqCst),
141        );
142        let (outgoing_tx, outgoing_rx) = mpsc::channel(64);
143        let connection = Connection {
144            outgoing_tx,
145            next_message_id: Default::default(),
146            response_channels: Default::default(),
147        };
148        let handler = ConnectionHandler {
149            peer: self.clone(),
150            connection_id,
151            response_channels: connection.response_channels.clone(),
152            outgoing_rx,
153            reader: MessageStream::new(conn.clone()),
154            writer: MessageStream::new(conn),
155        };
156        self.connections
157            .write()
158            .await
159            .insert(connection_id, connection);
160        (connection_id, handler)
161    }
162
163    pub async fn disconnect(&self, connection_id: ConnectionId) {
164        self.connections.write().await.remove(&connection_id);
165    }
166
167    pub async fn reset(&self) {
168        self.connections.write().await.clear();
169        self.handler_types.lock().await.clear();
170        self.message_handlers.write().await.clear();
171    }
172
    /// Sends `request` over the connection `receiver_id` and resolves to its typed response.
    ///
    /// Delegates to `request_internal` with no original sender recorded on the envelope.
    pub fn request<T: RequestMessage>(
        self: &Arc<Self>,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(None, receiver_id, request)
    }
180
    /// Like `request`, but records `sender_id` as the envelope's original sender.
    ///
    /// Delegates to `request_internal` with `Some(sender_id)`.
    pub fn forward_request<T: RequestMessage>(
        self: &Arc<Self>,
        sender_id: ConnectionId,
        receiver_id: ConnectionId,
        request: T,
    ) -> impl Future<Output = Result<T::Response>> {
        self.request_internal(Some(sender_id), receiver_id, request)
    }
189
190    pub fn request_internal<T: RequestMessage>(
191        self: &Arc<Self>,
192        original_sender_id: Option<ConnectionId>,
193        receiver_id: ConnectionId,
194        request: T,
195    ) -> impl Future<Output = Result<T::Response>> {
196        let this = self.clone();
197        let (tx, mut rx) = mpsc::channel(1);
198        async move {
199            let mut connection = this.connection(receiver_id).await?;
200            let message_id = connection
201                .next_message_id
202                .fetch_add(1, atomic::Ordering::SeqCst);
203            connection
204                .response_channels
205                .lock()
206                .await
207                .insert(message_id, tx);
208            connection
209                .outgoing_tx
210                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
211                .await?;
212            let response = rx
213                .recv()
214                .await
215                .ok_or_else(|| anyhow!("connection was closed"))?;
216            T::Response::from_envelope(response)
217                .ok_or_else(|| anyhow!("received response of the wrong type"))
218        }
219    }
220
221    pub fn send<T: EnvelopedMessage>(
222        self: &Arc<Self>,
223        receiver_id: ConnectionId,
224        message: T,
225    ) -> impl Future<Output = Result<()>> {
226        let this = self.clone();
227        async move {
228            let mut connection = this.connection(receiver_id).await?;
229            let message_id = connection
230                .next_message_id
231                .fetch_add(1, atomic::Ordering::SeqCst);
232            connection
233                .outgoing_tx
234                .send(message.into_envelope(message_id, None, None))
235                .await?;
236            Ok(())
237        }
238    }
239
240    pub fn forward_send<T: EnvelopedMessage>(
241        self: &Arc<Self>,
242        sender_id: ConnectionId,
243        receiver_id: ConnectionId,
244        message: T,
245    ) -> impl Future<Output = Result<()>> {
246        let this = self.clone();
247        async move {
248            let mut connection = this.connection(receiver_id).await?;
249            let message_id = connection
250                .next_message_id
251                .fetch_add(1, atomic::Ordering::SeqCst);
252            connection
253                .outgoing_tx
254                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
255                .await?;
256            Ok(())
257        }
258    }
259
260    pub fn respond<T: RequestMessage>(
261        self: &Arc<Self>,
262        receipt: Receipt<T>,
263        response: T::Response,
264    ) -> impl Future<Output = Result<()>> {
265        let this = self.clone();
266        async move {
267            let mut connection = this.connection(receipt.sender_id).await?;
268            let message_id = connection
269                .next_message_id
270                .fetch_add(1, atomic::Ordering::SeqCst);
271            connection
272                .outgoing_tx
273                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
274                .await?;
275            Ok(())
276        }
277    }
278
279    fn connection(
280        self: &Arc<Self>,
281        connection_id: ConnectionId,
282    ) -> impl Future<Output = Result<Connection>> {
283        let this = self.clone();
284        async move {
285            let connections = this.connections.read().await;
286            let connection = connections
287                .get(&connection_id)
288                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
289            Ok(connection.clone())
290        }
291    }
292}
293
294impl<Conn> ConnectionHandler<Conn>
295where
296    Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
297{
    /// Drives this connection: reads incoming messages and dispatches them through
    /// `handle_incoming_message`, while writing envelopes taken from the outgoing queue.
    ///
    /// Returns `Ok(())` once the outgoing queue yields `None`.
    ///
    /// # Errors
    ///
    /// Returns an error when a message cannot be read ("received invalid RPC message") or
    /// written ("failed to write RPC message"). In both cases the response-channel table is
    /// cleared first, which drops the senders that pending requests are waiting on.
    pub async fn run(mut self) -> Result<()> {
        loop {
            // One read future per incoming message. It is created here and polled across
            // iterations of the inner loop, so writing outgoing messages in between does
            // not restart the read.
            let read_message = self.reader.read_message().fuse();
            futures::pin_mut!(read_message);
            loop {
                futures::select! {
                    incoming = read_message => match incoming {
                        Ok(incoming) => {
                            Self::handle_incoming_message(incoming, &self.peer, self.connection_id, &self.response_channels).await;
                            // The read future is finished; leave the inner loop to make a new one.
                            break;
                        }
                        Err(error) => {
                            self.response_channels.lock().await.clear();
                            Err(error).context("received invalid RPC message")?;
                        }
                    },
                    outgoing = self.outgoing_rx.recv().fuse() => match outgoing {
                        Some(outgoing) => {
                            if let Err(result) = self.writer.write_message(&outgoing).await {
                                self.response_channels.lock().await.clear();
                                Err(result).context("failed to write RPC message")?;
                            }
                        }
                        // The outgoing queue is closed: nothing more can be sent, so stop.
                        None => return Ok(()),
                    }
                }
            }
        }
    }
327
328    pub async fn receive<M: EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
329        let envelope = self.reader.read_message().await?;
330        let original_sender_id = envelope.original_sender_id;
331        let message_id = envelope.id;
332        let payload =
333            M::from_envelope(envelope).ok_or_else(|| anyhow!("unexpected message type"))?;
334        Ok(TypedEnvelope {
335            sender_id: self.connection_id,
336            original_sender_id: original_sender_id.map(PeerId),
337            message_id,
338            payload,
339        })
340    }
341
342    async fn handle_incoming_message(
343        message: proto::Envelope,
344        peer: &Arc<Peer>,
345        connection_id: ConnectionId,
346        response_channels: &ResponseChannels,
347    ) {
348        if let Some(responding_to) = message.responding_to {
349            let channel = response_channels.lock().await.remove(&responding_to);
350            if let Some(mut tx) = channel {
351                tx.send(message).await.ok();
352            } else {
353                log::warn!("received RPC response to unknown request {}", responding_to);
354            }
355        } else {
356            let mut envelope = Some(message);
357            let mut handler_index = None;
358            let mut handler_was_dropped = false;
359            for (i, handler) in peer.message_handlers.read().await.iter().enumerate() {
360                if let Some(future) = handler(&mut envelope, connection_id) {
361                    handler_was_dropped = future.await;
362                    handler_index = Some(i);
363                    break;
364                }
365            }
366
367            if let Some(handler_index) = handler_index {
368                if handler_was_dropped {
369                    drop(peer.message_handlers.write().await.remove(handler_index));
370                }
371            } else {
372                log::warn!("unhandled message: {:?}", envelope.unwrap().payload);
373            }
374        }
375    }
376}
377
378impl<T> Clone for Receipt<T> {
379    fn clone(&self) -> Self {
380        Self {
381            sender_id: self.sender_id,
382            message_id: self.message_id,
383            payload_type: PhantomData,
384        }
385    }
386}
387
388impl<T> Copy for Receipt<T> {}
389
390impl fmt::Display for ConnectionId {
391    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
392        self.0.fmt(f)
393    }
394}
395
396impl fmt::Display for PeerId {
397    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
398        self.0.fmt(f)
399    }
400}
401
402#[cfg(test)]
403mod tests {
404    use super::*;
405    use postage::oneshot;
406    use smol::{
407        io::AsyncWriteExt,
408        net::unix::{UnixListener, UnixStream},
409    };
410    use std::io;
411    use tempdir::TempDir;
412
413    #[test]
414    fn test_request_response() {
415        smol::block_on(async move {
416            // create socket
417            let socket_dir_path = TempDir::new("test-request-response").unwrap();
418            let socket_path = socket_dir_path.path().join("test.sock");
419            let listener = UnixListener::bind(&socket_path).unwrap();
420
421            // create 2 clients connected to 1 server
422            let server = Peer::new();
423            let client1 = Peer::new();
424            let client2 = Peer::new();
425            let (client1_conn_id, task1) = client1
426                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
427                .await;
428            let (client2_conn_id, task2) = client2
429                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
430                .await;
431            let (_, task3) = server
432                .add_connection(listener.accept().await.unwrap().0)
433                .await;
434            let (_, task4) = server
435                .add_connection(listener.accept().await.unwrap().0)
436                .await;
437            smol::spawn(task1.run()).detach();
438            smol::spawn(task2.run()).detach();
439            smol::spawn(task3.run()).detach();
440            smol::spawn(task4.run()).detach();
441
442            // define the expected requests and responses
443            let request1 = proto::Auth {
444                user_id: 1,
445                access_token: "token-1".to_string(),
446            };
447            let response1 = proto::AuthResponse {
448                credentials_valid: true,
449            };
450            let request2 = proto::Auth {
451                user_id: 2,
452                access_token: "token-2".to_string(),
453            };
454            let response2 = proto::AuthResponse {
455                credentials_valid: false,
456            };
457            let request3 = proto::OpenBuffer {
458                worktree_id: 1,
459                path: "path/two".to_string(),
460            };
461            let response3 = proto::OpenBufferResponse {
462                buffer: Some(proto::Buffer {
463                    id: 2,
464                    content: "path/two content".to_string(),
465                    history: vec![],
466                    selections: vec![],
467                }),
468            };
469            let request4 = proto::OpenBuffer {
470                worktree_id: 2,
471                path: "path/one".to_string(),
472            };
473            let response4 = proto::OpenBufferResponse {
474                buffer: Some(proto::Buffer {
475                    id: 1,
476                    content: "path/one content".to_string(),
477                    history: vec![],
478                    selections: vec![],
479                }),
480            };
481
482            // on the server, respond to two requests for each client
483            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
484            let mut auth_rx = server.add_message_handler::<proto::Auth>().await;
485            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
486            smol::spawn({
487                let request1 = request1.clone();
488                let request2 = request2.clone();
489                let request3 = request3.clone();
490                let request4 = request4.clone();
491                let response1 = response1.clone();
492                let response2 = response2.clone();
493                let response3 = response3.clone();
494                let response4 = response4.clone();
495                async move {
496                    let msg = auth_rx.recv().await.unwrap();
497                    assert_eq!(msg.payload, request1);
498                    server
499                        .respond(msg.receipt(), response1.clone())
500                        .await
501                        .unwrap();
502
503                    let msg = auth_rx.recv().await.unwrap();
504                    assert_eq!(msg.payload, request2.clone());
505                    server
506                        .respond(msg.receipt(), response2.clone())
507                        .await
508                        .unwrap();
509
510                    let msg = open_buffer_rx.recv().await.unwrap();
511                    assert_eq!(msg.payload, request3.clone());
512                    server
513                        .respond(msg.receipt(), response3.clone())
514                        .await
515                        .unwrap();
516
517                    let msg = open_buffer_rx.recv().await.unwrap();
518                    assert_eq!(msg.payload, request4.clone());
519                    server
520                        .respond(msg.receipt(), response4.clone())
521                        .await
522                        .unwrap();
523
524                    server_done_tx.send(()).await.unwrap();
525                }
526            })
527            .detach();
528
529            assert_eq!(
530                client1.request(client1_conn_id, request1).await.unwrap(),
531                response1
532            );
533            assert_eq!(
534                client2.request(client2_conn_id, request2).await.unwrap(),
535                response2
536            );
537            assert_eq!(
538                client2.request(client2_conn_id, request3).await.unwrap(),
539                response3
540            );
541            assert_eq!(
542                client1.request(client1_conn_id, request4).await.unwrap(),
543                response4
544            );
545
546            client1.disconnect(client1_conn_id).await;
547            client2.disconnect(client1_conn_id).await;
548
549            server_done_rx.recv().await.unwrap();
550        });
551    }
552
553    #[test]
554    fn test_disconnect() {
555        smol::block_on(async move {
556            let socket_dir_path = TempDir::new("drop-client").unwrap();
557            let socket_path = socket_dir_path.path().join(".sock");
558            let listener = UnixListener::bind(&socket_path).unwrap();
559            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
560            let (mut server_conn, _) = listener.accept().await.unwrap();
561
562            let client = Peer::new();
563            let (connection_id, handler) = client.add_connection(client_conn).await;
564            let (mut incoming_messages_ended_tx, mut incoming_messages_ended_rx) =
565                postage::barrier::channel();
566            smol::spawn(async move {
567                handler.run().await.ok();
568                incoming_messages_ended_tx.send(()).await.unwrap();
569            })
570            .detach();
571            client.disconnect(connection_id).await;
572
573            incoming_messages_ended_rx.recv().await;
574
575            let err = server_conn.write(&[]).await.unwrap_err();
576            assert_eq!(err.kind(), io::ErrorKind::BrokenPipe);
577        });
578    }
579
580    #[test]
581    fn test_io_error() {
582        smol::block_on(async move {
583            let socket_dir_path = TempDir::new("io-error").unwrap();
584            let socket_path = socket_dir_path.path().join(".sock");
585            let _listener = UnixListener::bind(&socket_path).unwrap();
586            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
587            client_conn.close().await.unwrap();
588
589            let client = Peer::new();
590            let (connection_id, handler) = client.add_connection(client_conn).await;
591            smol::spawn(handler.run()).detach();
592
593            let err = client
594                .request(
595                    connection_id,
596                    proto::Auth {
597                        user_id: 42,
598                        access_token: "token".to_string(),
599                    },
600                )
601                .await
602                .unwrap_err();
603            assert_eq!(err.to_string(), "connection was closed");
604        });
605    }
606}