peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use async_lock::{Mutex, RwLock};
  5use futures::FutureExt as _;
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use smol_timeout::TimeoutExt as _;
 11use std::{
 12    collections::HashMap,
 13    fmt,
 14    future::Future,
 15    marker::PhantomData,
 16    sync::{
 17        atomic::{self, AtomicU32},
 18        Arc,
 19    },
 20    time::Duration,
 21};
 22
 23#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 24pub struct ConnectionId(pub u32);
 25
 26impl fmt::Display for ConnectionId {
 27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 28        self.0.fmt(f)
 29    }
 30}
 31
 32#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 33pub struct PeerId(pub u32);
 34
 35impl fmt::Display for PeerId {
 36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 37        self.0.fmt(f)
 38    }
 39}
 40
 41pub struct Receipt<T> {
 42    pub sender_id: ConnectionId,
 43    pub message_id: u32,
 44    payload_type: PhantomData<T>,
 45}
 46
 47impl<T> Clone for Receipt<T> {
 48    fn clone(&self) -> Self {
 49        Self {
 50            sender_id: self.sender_id,
 51            message_id: self.message_id,
 52            payload_type: PhantomData,
 53        }
 54    }
 55}
 56
 57impl<T> Copy for Receipt<T> {}
 58
 59pub struct TypedEnvelope<T> {
 60    pub sender_id: ConnectionId,
 61    pub original_sender_id: Option<PeerId>,
 62    pub message_id: u32,
 63    pub payload: T,
 64}
 65
 66impl<T> TypedEnvelope<T> {
 67    pub fn original_sender_id(&self) -> Result<PeerId> {
 68        self.original_sender_id
 69            .ok_or_else(|| anyhow!("missing original_sender_id"))
 70    }
 71}
 72
 73impl<T: RequestMessage> TypedEnvelope<T> {
 74    pub fn receipt(&self) -> Receipt<T> {
 75        Receipt {
 76            sender_id: self.sender_id,
 77            message_id: self.message_id,
 78            payload_type: PhantomData,
 79        }
 80    }
 81}
 82
 83pub struct Peer {
 84    connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 85    next_connection_id: AtomicU32,
 86}
 87
 88#[derive(Clone)]
 89struct ConnectionState {
 90    outgoing_tx: mpsc::Sender<proto::Envelope>,
 91    next_message_id: Arc<AtomicU32>,
 92    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 93}
 94
 95const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 96
 97impl Peer {
 98    pub fn new() -> Arc<Self> {
 99        Arc::new(Self {
100            connections: Default::default(),
101            next_connection_id: Default::default(),
102        })
103    }
104
105    pub async fn add_connection(
106        self: &Arc<Self>,
107        connection: Connection,
108    ) -> (
109        ConnectionId,
110        impl Future<Output = anyhow::Result<()>> + Send,
111        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
112    ) {
113        let connection_id = ConnectionId(
114            self.next_connection_id
115                .fetch_add(1, atomic::Ordering::SeqCst),
116        );
117        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
118        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
119        let connection_state = ConnectionState {
120            outgoing_tx,
121            next_message_id: Default::default(),
122            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
123        };
124        let mut writer = MessageStream::new(connection.tx);
125        let mut reader = MessageStream::new(connection.rx);
126
127        let this = self.clone();
128        let response_channels = connection_state.response_channels.clone();
129        let handle_io = async move {
130            let result = 'outer: loop {
131                let read_message = reader.read_message().fuse();
132                futures::pin_mut!(read_message);
133                loop {
134                    futures::select_biased! {
135                        incoming = read_message => match incoming {
136                            Ok(incoming) => {
137                                if let Some(responding_to) = incoming.responding_to {
138                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
139                                    if let Some(mut tx) = channel {
140                                        tx.send(incoming).await.ok();
141                                    } else {
142                                        log::warn!("received RPC response to unknown request {}", responding_to);
143                                    }
144                                } else {
145                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
146                                        if incoming_tx.send(envelope).await.is_err() {
147                                            break 'outer Ok(())
148                                        }
149                                    } else {
150                                        log::error!("unable to construct a typed envelope");
151                                    }
152                                }
153
154                                break;
155                            }
156                            Err(error) => {
157                                break 'outer Err(error).context("received invalid RPC message")
158                            }
159                        },
160                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
161                            Some(outgoing) => {
162                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
163                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
164                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
165                                    _ => {}
166                                }
167                            }
168                            None => break 'outer Ok(()),
169                        }
170                    }
171                }
172            };
173
174            response_channels.lock().await.take();
175            this.connections.write().await.remove(&connection_id);
176            result
177        };
178
179        self.connections
180            .write()
181            .await
182            .insert(connection_id, connection_state);
183
184        (connection_id, handle_io, incoming_rx)
185    }
186
187    pub async fn disconnect(&self, connection_id: ConnectionId) {
188        self.connections.write().await.remove(&connection_id);
189    }
190
191    pub async fn reset(&self) {
192        self.connections.write().await.clear();
193    }
194
195    pub fn request<T: RequestMessage>(
196        self: &Arc<Self>,
197        receiver_id: ConnectionId,
198        request: T,
199    ) -> impl Future<Output = Result<T::Response>> {
200        self.request_internal(None, receiver_id, request)
201    }
202
203    pub fn forward_request<T: RequestMessage>(
204        self: &Arc<Self>,
205        sender_id: ConnectionId,
206        receiver_id: ConnectionId,
207        request: T,
208    ) -> impl Future<Output = Result<T::Response>> {
209        self.request_internal(Some(sender_id), receiver_id, request)
210    }
211
212    pub fn request_internal<T: RequestMessage>(
213        self: &Arc<Self>,
214        original_sender_id: Option<ConnectionId>,
215        receiver_id: ConnectionId,
216        request: T,
217    ) -> impl Future<Output = Result<T::Response>> {
218        let this = self.clone();
219        let (tx, mut rx) = mpsc::channel(1);
220        async move {
221            let mut connection = this.connection_state(receiver_id).await?;
222            let message_id = connection
223                .next_message_id
224                .fetch_add(1, atomic::Ordering::SeqCst);
225            connection
226                .response_channels
227                .lock()
228                .await
229                .as_mut()
230                .ok_or_else(|| anyhow!("connection was closed"))?
231                .insert(message_id, tx);
232            connection
233                .outgoing_tx
234                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
235                .await
236                .map_err(|_| anyhow!("connection was closed"))?;
237            let response = rx
238                .recv()
239                .await
240                .ok_or_else(|| anyhow!("connection was closed"))?;
241            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
242                Err(anyhow!("request failed").context(error.message.clone()))
243            } else {
244                T::Response::from_envelope(response)
245                    .ok_or_else(|| anyhow!("received response of the wrong type"))
246            }
247        }
248    }
249
250    pub fn send<T: EnvelopedMessage>(
251        self: &Arc<Self>,
252        receiver_id: ConnectionId,
253        message: T,
254    ) -> impl Future<Output = Result<()>> {
255        let this = self.clone();
256        async move {
257            let mut connection = this.connection_state(receiver_id).await?;
258            let message_id = connection
259                .next_message_id
260                .fetch_add(1, atomic::Ordering::SeqCst);
261            connection
262                .outgoing_tx
263                .send(message.into_envelope(message_id, None, None))
264                .await?;
265            Ok(())
266        }
267    }
268
269    pub fn forward_send<T: EnvelopedMessage>(
270        self: &Arc<Self>,
271        sender_id: ConnectionId,
272        receiver_id: ConnectionId,
273        message: T,
274    ) -> impl Future<Output = Result<()>> {
275        let this = self.clone();
276        async move {
277            let mut connection = this.connection_state(receiver_id).await?;
278            let message_id = connection
279                .next_message_id
280                .fetch_add(1, atomic::Ordering::SeqCst);
281            connection
282                .outgoing_tx
283                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
284                .await?;
285            Ok(())
286        }
287    }
288
289    pub fn respond<T: RequestMessage>(
290        self: &Arc<Self>,
291        receipt: Receipt<T>,
292        response: T::Response,
293    ) -> impl Future<Output = Result<()>> {
294        let this = self.clone();
295        async move {
296            let mut connection = this.connection_state(receipt.sender_id).await?;
297            let message_id = connection
298                .next_message_id
299                .fetch_add(1, atomic::Ordering::SeqCst);
300            connection
301                .outgoing_tx
302                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
303                .await?;
304            Ok(())
305        }
306    }
307
308    pub fn respond_with_error<T: RequestMessage>(
309        self: &Arc<Self>,
310        receipt: Receipt<T>,
311        response: proto::Error,
312    ) -> impl Future<Output = Result<()>> {
313        let this = self.clone();
314        async move {
315            let mut connection = this.connection_state(receipt.sender_id).await?;
316            let message_id = connection
317                .next_message_id
318                .fetch_add(1, atomic::Ordering::SeqCst);
319            connection
320                .outgoing_tx
321                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
322                .await?;
323            Ok(())
324        }
325    }
326
327    fn connection_state(
328        self: &Arc<Self>,
329        connection_id: ConnectionId,
330    ) -> impl Future<Output = Result<ConnectionState>> {
331        let this = self.clone();
332        async move {
333            let connections = this.connections.read().await;
334            let connection = connections
335                .get(&connection_id)
336                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
337            Ok(connection.clone())
338        }
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TypedEnvelope;
    use async_tungstenite::tungstenite::Message as WebSocketMessage;
    use futures::StreamExt as _;

    /// Two clients, each connected to one server over its own in-memory connection, issue
    /// requests and receive the responses produced by `handle_messages` on the server.
    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping {},)
                    .await
                    .unwrap(),
                proto::Ack {}
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                        diagnostics: None,
                    }),
                }
            );

            client1.disconnect(client1_conn_id).await;
            // Each peer must be disconnected using the id *it* issued.
            client2.disconnect(client2_conn_id).await;

            // Server-side handler: answers `Ping` with `Ack` and `OpenBuffer` with a canned
            // buffer for the two known paths; panics on anything else.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    let envelope = envelope.into_any();
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(receipt, proto::Ack {}).await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                        diagnostics: None,
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    /// Disconnecting locally ends the connection's I/O future and its incoming stream, and
    /// closes the underlying connection so the remote side can no longer send on it.
    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(server_conn
                .send(WebSocketMessage::Binary(vec![]))
                .await
                .is_err());
        });
    }

    /// When the remote end drops the connection after receiving a request, the pending
    /// request fails with "connection was closed".
    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, mut server_conn, _) = Connection::in_memory();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
            let _request = server_conn.rx.next().await.unwrap().unwrap();

            drop(server_conn);
            assert_eq!(
                response.await.unwrap_err().to_string(),
                "connection was closed"
            );
        });
    }
}