peer.rs

  1use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
  2use super::Connection;
  3use anyhow::{anyhow, Context, Result};
  4use async_lock::{Mutex, RwLock};
  5use futures::FutureExt as _;
  6use postage::{
  7    mpsc,
  8    prelude::{Sink as _, Stream as _},
  9};
 10use std::{
 11    collections::HashMap,
 12    fmt,
 13    future::Future,
 14    marker::PhantomData,
 15    sync::{
 16        atomic::{self, AtomicU32},
 17        Arc,
 18    },
 19};
 20
 21#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 22pub struct ConnectionId(pub u32);
 23
 24impl fmt::Display for ConnectionId {
 25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 26        self.0.fmt(f)
 27    }
 28}
 29
 30#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 31pub struct PeerId(pub u32);
 32
 33impl fmt::Display for PeerId {
 34    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 35        self.0.fmt(f)
 36    }
 37}
 38
 39pub struct Receipt<T> {
 40    pub sender_id: ConnectionId,
 41    pub message_id: u32,
 42    payload_type: PhantomData<T>,
 43}
 44
 45impl<T> Clone for Receipt<T> {
 46    fn clone(&self) -> Self {
 47        Self {
 48            sender_id: self.sender_id,
 49            message_id: self.message_id,
 50            payload_type: PhantomData,
 51        }
 52    }
 53}
 54
 55impl<T> Copy for Receipt<T> {}
 56
 57pub struct TypedEnvelope<T> {
 58    pub sender_id: ConnectionId,
 59    pub original_sender_id: Option<PeerId>,
 60    pub message_id: u32,
 61    pub payload: T,
 62}
 63
 64impl<T> TypedEnvelope<T> {
 65    pub fn original_sender_id(&self) -> Result<PeerId> {
 66        self.original_sender_id
 67            .ok_or_else(|| anyhow!("missing original_sender_id"))
 68    }
 69}
 70
 71impl<T: RequestMessage> TypedEnvelope<T> {
 72    pub fn receipt(&self) -> Receipt<T> {
 73        Receipt {
 74            sender_id: self.sender_id,
 75            message_id: self.message_id,
 76            payload_type: PhantomData,
 77        }
 78    }
 79}
 80
 81pub struct Peer {
 82    connections: RwLock<HashMap<ConnectionId, ConnectionState>>,
 83    next_connection_id: AtomicU32,
 84}
 85
 86#[derive(Clone)]
 87struct ConnectionState {
 88    outgoing_tx: mpsc::Sender<proto::Envelope>,
 89    next_message_id: Arc<AtomicU32>,
 90    response_channels: Arc<Mutex<Option<HashMap<u32, mpsc::Sender<proto::Envelope>>>>>,
 91}
 92
 93impl Peer {
 94    pub fn new() -> Arc<Self> {
 95        Arc::new(Self {
 96            connections: Default::default(),
 97            next_connection_id: Default::default(),
 98        })
 99    }
100
101    pub async fn add_connection(
102        self: &Arc<Self>,
103        connection: Connection,
104    ) -> (
105        ConnectionId,
106        impl Future<Output = anyhow::Result<()>> + Send,
107        mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
108    ) {
109        let connection_id = ConnectionId(
110            self.next_connection_id
111                .fetch_add(1, atomic::Ordering::SeqCst),
112        );
113        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
114        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
115        let connection_state = ConnectionState {
116            outgoing_tx,
117            next_message_id: Default::default(),
118            response_channels: Arc::new(Mutex::new(Some(Default::default()))),
119        };
120        let mut writer = MessageStream::new(connection.tx);
121        let mut reader = MessageStream::new(connection.rx);
122
123        let this = self.clone();
124        let response_channels = connection_state.response_channels.clone();
125        let handle_io = async move {
126            let result = 'outer: loop {
127                let read_message = reader.read_message().fuse();
128                futures::pin_mut!(read_message);
129                loop {
130                    futures::select_biased! {
131                        incoming = read_message => match incoming {
132                            Ok(incoming) => {
133                                if let Some(responding_to) = incoming.responding_to {
134                                    let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to);
135                                    if let Some(mut tx) = channel {
136                                        tx.send(incoming).await.ok();
137                                    } else {
138                                        log::warn!("received RPC response to unknown request {}", responding_to);
139                                    }
140                                } else {
141                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
142                                        if incoming_tx.send(envelope).await.is_err() {
143                                            break 'outer Ok(())
144                                        }
145                                    } else {
146                                        log::error!("unable to construct a typed envelope");
147                                    }
148                                }
149
150                                break;
151                            }
152                            Err(error) => {
153                                break 'outer Err(error).context("received invalid RPC message")
154                            }
155                        },
156                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
157                            Some(outgoing) => {
158                                if let Err(result) = writer.write_message(&outgoing).await {
159                                    break 'outer Err(result).context("failed to write RPC message")
160                                }
161                            }
162                            None => break 'outer Ok(()),
163                        }
164                    }
165                }
166            };
167
168            response_channels.lock().await.take();
169            this.connections.write().await.remove(&connection_id);
170            result
171        };
172
173        self.connections
174            .write()
175            .await
176            .insert(connection_id, connection_state);
177
178        (connection_id, handle_io, incoming_rx)
179    }
180
181    pub async fn disconnect(&self, connection_id: ConnectionId) {
182        self.connections.write().await.remove(&connection_id);
183    }
184
185    pub async fn reset(&self) {
186        self.connections.write().await.clear();
187    }
188
189    pub fn request<T: RequestMessage>(
190        self: &Arc<Self>,
191        receiver_id: ConnectionId,
192        request: T,
193    ) -> impl Future<Output = Result<T::Response>> {
194        self.request_internal(None, receiver_id, request)
195    }
196
197    pub fn forward_request<T: RequestMessage>(
198        self: &Arc<Self>,
199        sender_id: ConnectionId,
200        receiver_id: ConnectionId,
201        request: T,
202    ) -> impl Future<Output = Result<T::Response>> {
203        self.request_internal(Some(sender_id), receiver_id, request)
204    }
205
206    pub fn request_internal<T: RequestMessage>(
207        self: &Arc<Self>,
208        original_sender_id: Option<ConnectionId>,
209        receiver_id: ConnectionId,
210        request: T,
211    ) -> impl Future<Output = Result<T::Response>> {
212        let this = self.clone();
213        let (tx, mut rx) = mpsc::channel(1);
214        async move {
215            let mut connection = this.connection_state(receiver_id).await?;
216            let message_id = connection
217                .next_message_id
218                .fetch_add(1, atomic::Ordering::SeqCst);
219            connection
220                .response_channels
221                .lock()
222                .await
223                .as_mut()
224                .ok_or_else(|| anyhow!("connection was closed"))?
225                .insert(message_id, tx);
226            connection
227                .outgoing_tx
228                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
229                .await
230                .map_err(|_| anyhow!("connection was closed"))?;
231            let response = rx
232                .recv()
233                .await
234                .ok_or_else(|| anyhow!("connection was closed"))?;
235            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
236                Err(anyhow!("request failed").context(error.message.clone()))
237            } else {
238                T::Response::from_envelope(response)
239                    .ok_or_else(|| anyhow!("received response of the wrong type"))
240            }
241        }
242    }
243
244    pub fn send<T: EnvelopedMessage>(
245        self: &Arc<Self>,
246        receiver_id: ConnectionId,
247        message: T,
248    ) -> impl Future<Output = Result<()>> {
249        let this = self.clone();
250        async move {
251            let mut connection = this.connection_state(receiver_id).await?;
252            let message_id = connection
253                .next_message_id
254                .fetch_add(1, atomic::Ordering::SeqCst);
255            connection
256                .outgoing_tx
257                .send(message.into_envelope(message_id, None, None))
258                .await?;
259            Ok(())
260        }
261    }
262
263    pub fn forward_send<T: EnvelopedMessage>(
264        self: &Arc<Self>,
265        sender_id: ConnectionId,
266        receiver_id: ConnectionId,
267        message: T,
268    ) -> impl Future<Output = Result<()>> {
269        let this = self.clone();
270        async move {
271            let mut connection = this.connection_state(receiver_id).await?;
272            let message_id = connection
273                .next_message_id
274                .fetch_add(1, atomic::Ordering::SeqCst);
275            connection
276                .outgoing_tx
277                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
278                .await?;
279            Ok(())
280        }
281    }
282
283    pub fn respond<T: RequestMessage>(
284        self: &Arc<Self>,
285        receipt: Receipt<T>,
286        response: T::Response,
287    ) -> impl Future<Output = Result<()>> {
288        let this = self.clone();
289        async move {
290            let mut connection = this.connection_state(receipt.sender_id).await?;
291            let message_id = connection
292                .next_message_id
293                .fetch_add(1, atomic::Ordering::SeqCst);
294            connection
295                .outgoing_tx
296                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
297                .await?;
298            Ok(())
299        }
300    }
301
302    pub fn respond_with_error<T: RequestMessage>(
303        self: &Arc<Self>,
304        receipt: Receipt<T>,
305        response: proto::Error,
306    ) -> impl Future<Output = Result<()>> {
307        let this = self.clone();
308        async move {
309            let mut connection = this.connection_state(receipt.sender_id).await?;
310            let message_id = connection
311                .next_message_id
312                .fetch_add(1, atomic::Ordering::SeqCst);
313            connection
314                .outgoing_tx
315                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
316                .await?;
317            Ok(())
318        }
319    }
320
321    fn connection_state(
322        self: &Arc<Self>,
323        connection_id: ConnectionId,
324    ) -> impl Future<Output = Result<ConnectionState>> {
325        let this = self.clone();
326        async move {
327            let connections = this.connections.read().await;
328            let connection = connections
329                .get(&connection_id)
330                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
331            Ok(connection.clone())
332        }
333    }
334}
335
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use crate::TypedEnvelope;
340    use async_tungstenite::tungstenite::Message as WebSocketMessage;
341    use futures::StreamExt as _;
342
343    #[test]
344    fn test_request_response() {
345        smol::block_on(async move {
346            // create 2 clients connected to 1 server
347            let server = Peer::new();
348            let client1 = Peer::new();
349            let client2 = Peer::new();
350
351            let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory();
352            let (client1_conn_id, io_task1, _) =
353                client1.add_connection(client1_to_server_conn).await;
354            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
355
356            let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory();
357            let (client2_conn_id, io_task3, _) =
358                client2.add_connection(client2_to_server_conn).await;
359            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
360
361            smol::spawn(io_task1).detach();
362            smol::spawn(io_task2).detach();
363            smol::spawn(io_task3).detach();
364            smol::spawn(io_task4).detach();
365            smol::spawn(handle_messages(incoming1, server.clone())).detach();
366            smol::spawn(handle_messages(incoming2, server.clone())).detach();
367
368            assert_eq!(
369                client1
370                    .request(client1_conn_id, proto::Ping {},)
371                    .await
372                    .unwrap(),
373                proto::Ack {}
374            );
375
376            assert_eq!(
377                client2
378                    .request(client2_conn_id, proto::Ping {},)
379                    .await
380                    .unwrap(),
381                proto::Ack {}
382            );
383
384            assert_eq!(
385                client1
386                    .request(
387                        client1_conn_id,
388                        proto::OpenBuffer {
389                            worktree_id: 1,
390                            path: "path/one".to_string(),
391                        },
392                    )
393                    .await
394                    .unwrap(),
395                proto::OpenBufferResponse {
396                    buffer: Some(proto::Buffer {
397                        id: 101,
398                        content: "path/one content".to_string(),
399                        history: vec![],
400                        selections: vec![],
401                        diagnostics: None,
402                    }),
403                }
404            );
405
406            assert_eq!(
407                client2
408                    .request(
409                        client2_conn_id,
410                        proto::OpenBuffer {
411                            worktree_id: 2,
412                            path: "path/two".to_string(),
413                        },
414                    )
415                    .await
416                    .unwrap(),
417                proto::OpenBufferResponse {
418                    buffer: Some(proto::Buffer {
419                        id: 102,
420                        content: "path/two content".to_string(),
421                        history: vec![],
422                        selections: vec![],
423                        diagnostics: None,
424                    }),
425                }
426            );
427
428            client1.disconnect(client1_conn_id).await;
429            client2.disconnect(client1_conn_id).await;
430
431            async fn handle_messages(
432                mut messages: mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
433                peer: Arc<Peer>,
434            ) -> Result<()> {
435                while let Some(envelope) = messages.next().await {
436                    let envelope = envelope.into_any();
437                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
438                        let receipt = envelope.receipt();
439                        peer.respond(receipt, proto::Ack {}).await?
440                    } else if let Some(envelope) =
441                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
442                    {
443                        let message = &envelope.payload;
444                        let receipt = envelope.receipt();
445                        let response = match message.path.as_str() {
446                            "path/one" => {
447                                assert_eq!(message.worktree_id, 1);
448                                proto::OpenBufferResponse {
449                                    buffer: Some(proto::Buffer {
450                                        id: 101,
451                                        content: "path/one content".to_string(),
452                                        history: vec![],
453                                        selections: vec![],
454                                        diagnostics: None,
455                                    }),
456                                }
457                            }
458                            "path/two" => {
459                                assert_eq!(message.worktree_id, 2);
460                                proto::OpenBufferResponse {
461                                    buffer: Some(proto::Buffer {
462                                        id: 102,
463                                        content: "path/two content".to_string(),
464                                        history: vec![],
465                                        selections: vec![],
466                                        diagnostics: None,
467                                    }),
468                                }
469                            }
470                            _ => {
471                                panic!("unexpected path {}", message.path);
472                            }
473                        };
474
475                        peer.respond(receipt, response).await?
476                    } else {
477                        panic!("unknown message type");
478                    }
479                }
480
481                Ok(())
482            }
483        });
484    }
485
486    #[test]
487    fn test_disconnect() {
488        smol::block_on(async move {
489            let (client_conn, mut server_conn, _) = Connection::in_memory();
490
491            let client = Peer::new();
492            let (connection_id, io_handler, mut incoming) =
493                client.add_connection(client_conn).await;
494
495            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
496            smol::spawn(async move {
497                io_handler.await.ok();
498                io_ended_tx.send(()).await.unwrap();
499            })
500            .detach();
501
502            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
503            smol::spawn(async move {
504                incoming.next().await;
505                messages_ended_tx.send(()).await.unwrap();
506            })
507            .detach();
508
509            client.disconnect(connection_id).await;
510
511            io_ended_rx.recv().await;
512            messages_ended_rx.recv().await;
513            assert!(server_conn
514                .send(WebSocketMessage::Binary(vec![]))
515                .await
516                .is_err());
517        });
518    }
519
520    #[test]
521    fn test_io_error() {
522        smol::block_on(async move {
523            let (client_conn, mut server_conn, _) = Connection::in_memory();
524
525            let client = Peer::new();
526            let (connection_id, io_handler, mut incoming) =
527                client.add_connection(client_conn).await;
528            smol::spawn(io_handler).detach();
529            smol::spawn(async move { incoming.next().await }).detach();
530
531            let response = smol::spawn(client.request(connection_id, proto::Ping {}));
532            let _request = server_conn.rx.next().await.unwrap().unwrap();
533
534            drop(server_conn);
535            assert_eq!(
536                response.await.unwrap_err().to_string(),
537                "connection was closed"
538            );
539        });
540    }
541}