peer2.rs

  1use crate::{
  2    proto::{self, EnvelopedMessage, MessageStream, RequestMessage},
  3    ConnectionId, PeerId, Receipt,
  4};
  5use anyhow::{anyhow, Context, Result};
  6use async_lock::{Mutex, RwLock};
  7use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
  8use futures::{FutureExt, StreamExt};
  9use postage::{
 10    mpsc,
 11    prelude::{Sink as _, Stream as _},
 12};
 13use std::{
 14    any::Any,
 15    collections::HashMap,
 16    future::Future,
 17    sync::{
 18        atomic::{self, AtomicU32},
 19        Arc,
 20    },
 21};
 22
 23pub struct Peer {
 24    connections: RwLock<HashMap<ConnectionId, Connection>>,
 25    next_connection_id: AtomicU32,
 26}
 27
 28#[derive(Clone)]
 29struct Connection {
 30    outgoing_tx: mpsc::Sender<proto::Envelope>,
 31    next_message_id: Arc<AtomicU32>,
 32    response_channels: Arc<Mutex<HashMap<u32, mpsc::Sender<proto::Envelope>>>>,
 33}
 34
 35impl Peer {
 36    pub fn new() -> Arc<Self> {
 37        Arc::new(Self {
 38            connections: Default::default(),
 39            next_connection_id: Default::default(),
 40        })
 41    }
 42
 43    pub async fn add_connection<Conn>(
 44        self: &Arc<Self>,
 45        conn: Conn,
 46    ) -> (
 47        ConnectionId,
 48        impl Future<Output = anyhow::Result<()>> + Send,
 49        mpsc::Receiver<Box<dyn Any + Sync + Send>>,
 50    )
 51    where
 52        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
 53            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
 54            + Send
 55            + Unpin,
 56    {
 57        let (tx, rx) = conn.split();
 58        let connection_id = ConnectionId(
 59            self.next_connection_id
 60                .fetch_add(1, atomic::Ordering::SeqCst),
 61        );
 62        let (mut incoming_tx, incoming_rx) = mpsc::channel(64);
 63        let (outgoing_tx, mut outgoing_rx) = mpsc::channel(64);
 64        let connection = Connection {
 65            outgoing_tx,
 66            next_message_id: Default::default(),
 67            response_channels: Default::default(),
 68        };
 69        let mut writer = MessageStream::new(tx);
 70        let mut reader = MessageStream::new(rx);
 71
 72        let response_channels = connection.response_channels.clone();
 73        let handle_io = async move {
 74            loop {
 75                let read_message = reader.read_message().fuse();
 76                futures::pin_mut!(read_message);
 77                loop {
 78                    futures::select_biased! {
 79                        incoming = read_message => match incoming {
 80                            Ok(incoming) => {
 81                                if let Some(responding_to) = incoming.responding_to {
 82                                    let channel = response_channels.lock().await.remove(&responding_to);
 83                                    if let Some(mut tx) = channel {
 84                                        tx.send(incoming).await.ok();
 85                                    } else {
 86                                        log::warn!("received RPC response to unknown request {}", responding_to);
 87                                    }
 88                                } else {
 89                                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
 90                                        if incoming_tx.send(envelope).await.is_err() {
 91                                            response_channels.lock().await.clear();
 92                                            return Ok(())
 93                                        }
 94                                    } else {
 95                                        log::error!("unable to construct a typed envelope");
 96                                    }
 97                                }
 98
 99                                break;
100                            }
101                            Err(error) => {
102                                response_channels.lock().await.clear();
103                                Err(error).context("received invalid RPC message")?;
104                            }
105                        },
106                        outgoing = outgoing_rx.recv().fuse() => match outgoing {
107                            Some(outgoing) => {
108                                if let Err(result) = writer.write_message(&outgoing).await {
109                                    response_channels.lock().await.clear();
110                                    Err(result).context("failed to write RPC message")?;
111                                }
112                            }
113                            None => {
114                                response_channels.lock().await.clear();
115                                return Ok(())
116                            }
117                        }
118                    }
119                }
120            }
121        };
122
123        self.connections
124            .write()
125            .await
126            .insert(connection_id, connection);
127
128        (connection_id, handle_io, incoming_rx)
129    }
130
131    pub async fn disconnect(&self, connection_id: ConnectionId) {
132        self.connections.write().await.remove(&connection_id);
133    }
134
135    pub async fn reset(&self) {
136        self.connections.write().await.clear();
137    }
138
139    pub fn request<T: RequestMessage>(
140        self: &Arc<Self>,
141        receiver_id: ConnectionId,
142        request: T,
143    ) -> impl Future<Output = Result<T::Response>> {
144        self.request_internal(None, receiver_id, request)
145    }
146
147    pub fn forward_request<T: RequestMessage>(
148        self: &Arc<Self>,
149        sender_id: ConnectionId,
150        receiver_id: ConnectionId,
151        request: T,
152    ) -> impl Future<Output = Result<T::Response>> {
153        self.request_internal(Some(sender_id), receiver_id, request)
154    }
155
156    pub fn request_internal<T: RequestMessage>(
157        self: &Arc<Self>,
158        original_sender_id: Option<ConnectionId>,
159        receiver_id: ConnectionId,
160        request: T,
161    ) -> impl Future<Output = Result<T::Response>> {
162        let this = self.clone();
163        let (tx, mut rx) = mpsc::channel(1);
164        async move {
165            let mut connection = this.connection(receiver_id).await?;
166            let message_id = connection
167                .next_message_id
168                .fetch_add(1, atomic::Ordering::SeqCst);
169            connection
170                .response_channels
171                .lock()
172                .await
173                .insert(message_id, tx);
174            connection
175                .outgoing_tx
176                .send(request.into_envelope(message_id, None, original_sender_id.map(|id| id.0)))
177                .await
178                .map_err(|_| anyhow!("connection was closed"))?;
179            let response = rx
180                .recv()
181                .await
182                .ok_or_else(|| anyhow!("connection was closed"))?;
183            T::Response::from_envelope(response)
184                .ok_or_else(|| anyhow!("received response of the wrong type"))
185        }
186    }
187
188    pub fn send<T: EnvelopedMessage>(
189        self: &Arc<Self>,
190        receiver_id: ConnectionId,
191        message: T,
192    ) -> impl Future<Output = Result<()>> {
193        let this = self.clone();
194        async move {
195            let mut connection = this.connection(receiver_id).await?;
196            let message_id = connection
197                .next_message_id
198                .fetch_add(1, atomic::Ordering::SeqCst);
199            connection
200                .outgoing_tx
201                .send(message.into_envelope(message_id, None, None))
202                .await?;
203            Ok(())
204        }
205    }
206
207    pub fn forward_send<T: EnvelopedMessage>(
208        self: &Arc<Self>,
209        sender_id: ConnectionId,
210        receiver_id: ConnectionId,
211        message: T,
212    ) -> impl Future<Output = Result<()>> {
213        let this = self.clone();
214        async move {
215            let mut connection = this.connection(receiver_id).await?;
216            let message_id = connection
217                .next_message_id
218                .fetch_add(1, atomic::Ordering::SeqCst);
219            connection
220                .outgoing_tx
221                .send(message.into_envelope(message_id, None, Some(sender_id.0)))
222                .await?;
223            Ok(())
224        }
225    }
226
227    pub fn respond<T: RequestMessage>(
228        self: &Arc<Self>,
229        receipt: Receipt<T>,
230        response: T::Response,
231    ) -> impl Future<Output = Result<()>> {
232        let this = self.clone();
233        async move {
234            let mut connection = this.connection(receipt.sender_id).await?;
235            let message_id = connection
236                .next_message_id
237                .fetch_add(1, atomic::Ordering::SeqCst);
238            connection
239                .outgoing_tx
240                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
241                .await?;
242            Ok(())
243        }
244    }
245
246    fn connection(
247        self: &Arc<Self>,
248        connection_id: ConnectionId,
249    ) -> impl Future<Output = Result<Connection>> {
250        let this = self.clone();
251        async move {
252            let connections = this.connections.read().await;
253            let connection = connections
254                .get(&connection_id)
255                .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?;
256            Ok(connection.clone())
257        }
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test, TypedEnvelope};

    #[test]
    fn test_request_response() {
        smol::block_on(async move {
            // create 2 clients connected to 1 server
            let server = Peer::new();
            let client1 = Peer::new();
            let client2 = Peer::new();

            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
            let (client1_conn_id, io_task1, _) =
                client1.add_connection(client1_to_server_conn).await;
            let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;

            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
            let (client2_conn_id, io_task3, _) =
                client2.add_connection(client2_to_server_conn).await;
            let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;

            smol::spawn(io_task1).detach();
            smol::spawn(io_task2).detach();
            smol::spawn(io_task3).detach();
            smol::spawn(io_task4).detach();
            smol::spawn(handle_messages(incoming1, server.clone())).detach();
            smol::spawn(handle_messages(incoming2, server.clone())).detach();

            assert_eq!(
                client1
                    .request(client1_conn_id, proto::Ping { id: 1 },)
                    .await
                    .unwrap(),
                proto::Pong { id: 1 }
            );

            assert_eq!(
                client2
                    .request(client2_conn_id, proto::Ping { id: 2 },)
                    .await
                    .unwrap(),
                proto::Pong { id: 2 }
            );

            assert_eq!(
                client1
                    .request(
                        client1_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 1,
                            path: "path/one".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 101,
                        content: "path/one content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            assert_eq!(
                client2
                    .request(
                        client2_conn_id,
                        proto::OpenBuffer {
                            worktree_id: 2,
                            path: "path/two".to_string(),
                        },
                    )
                    .await
                    .unwrap(),
                proto::OpenBufferResponse {
                    buffer: Some(proto::Buffer {
                        id: 102,
                        content: "path/two content".to_string(),
                        history: vec![],
                        selections: vec![],
                    }),
                }
            );

            // Each client disconnects its own connection to the server.
            client1.disconnect(client1_conn_id).await;
            client2.disconnect(client2_conn_id).await;

            /// Serves `Ping` and `OpenBuffer` requests arriving on `messages`, answering through
            /// `peer`. Panics on any other message type or an unexpected path.
            async fn handle_messages(
                mut messages: mpsc::Receiver<Box<dyn Any + Sync + Send>>,
                peer: Arc<Peer>,
            ) -> Result<()> {
                while let Some(envelope) = messages.next().await {
                    if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                        let receipt = envelope.receipt();
                        peer.respond(
                            receipt,
                            proto::Pong {
                                id: envelope.payload.id,
                            },
                        )
                        .await?
                    } else if let Some(envelope) =
                        envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                    {
                        let message = &envelope.payload;
                        let receipt = envelope.receipt();
                        let response = match message.path.as_str() {
                            "path/one" => {
                                assert_eq!(message.worktree_id, 1);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 101,
                                        content: "path/one content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            "path/two" => {
                                assert_eq!(message.worktree_id, 2);
                                proto::OpenBufferResponse {
                                    buffer: Some(proto::Buffer {
                                        id: 102,
                                        content: "path/two content".to_string(),
                                        history: vec![],
                                        selections: vec![],
                                    }),
                                }
                            }
                            _ => {
                                panic!("unexpected path {}", message.path);
                            }
                        };

                        peer.respond(receipt, response).await?
                    } else {
                        panic!("unknown message type");
                    }
                }

                Ok(())
            }
        });
    }

    #[test]
    fn test_disconnect() {
        smol::block_on(async move {
            let (client_conn, mut server_conn) = test::Channel::bidirectional();

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;

            let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                io_handler.await.ok();
                io_ended_tx.send(()).await.unwrap();
            })
            .detach();

            let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel();
            smol::spawn(async move {
                incoming.next().await;
                messages_ended_tx.send(()).await.unwrap();
            })
            .detach();

            client.disconnect(connection_id).await;

            io_ended_rx.recv().await;
            messages_ended_rx.recv().await;
            assert!(
                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
                    .await
                    .is_err()
            );
        });
    }

    #[test]
    fn test_io_error() {
        smol::block_on(async move {
            let (client_conn, server_conn) = test::Channel::bidirectional();
            drop(server_conn);

            let client = Peer::new();
            let (connection_id, io_handler, mut incoming) =
                client.add_connection(client_conn).await;
            smol::spawn(io_handler).detach();
            smol::spawn(async move { incoming.next().await }).detach();

            let err = client
                .request(
                    connection_id,
                    proto::Auth {
                        user_id: 42,
                        access_token: "token".to_string(),
                    },
                )
                .await
                .unwrap_err();
            assert_eq!(err.to_string(), "connection was closed");
        });
    }
}
470}