conn.rs

 1use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 2use futures::{SinkExt as _, StreamExt as _};
 3
 4pub struct Conn {
 5    pub(crate) tx:
 6        Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
 7    pub(crate) rx: Box<
 8        dyn 'static
 9            + Send
10            + Unpin
11            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
12    >,
13}
14
15impl Conn {
16    pub fn new<S>(stream: S) -> Self
17    where
18        S: 'static
19            + Send
20            + Unpin
21            + futures::Sink<WebSocketMessage, Error = WebSocketError>
22            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
23    {
24        let (tx, rx) = stream.split();
25        Self {
26            tx: Box::new(tx),
27            rx: Box::new(rx),
28        }
29    }
30
31    pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
32        self.tx.send(message).await
33    }
34
35    #[cfg(any(test, feature = "test-support"))]
36    pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
37        let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
38        postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
39
40        let (a_tx, a_rx) = Self::channel(kill_rx.clone());
41        let (b_tx, b_rx) = Self::channel(kill_rx);
42        (
43            Self { tx: a_tx, rx: b_rx },
44            Self { tx: b_tx, rx: a_rx },
45            kill_tx,
46        )
47    }
48
49    #[cfg(any(test, feature = "test-support"))]
50    fn channel(
51        kill_rx: postage::watch::Receiver<Option<()>>,
52    ) -> (
53        Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
54        Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
55    ) {
56        use futures::{future, stream, SinkExt as _, StreamExt as _};
57        use std::io::{Error, ErrorKind};
58
59        let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
60        let tx = tx
61            .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
62            .with({
63                let kill_rx = kill_rx.clone();
64                move |msg| {
65                    if kill_rx.borrow().is_none() {
66                        future::ready(Ok(msg))
67                    } else {
68                        future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
69                    }
70                }
71            });
72        let rx = stream::select(
73            rx.map(Ok),
74            kill_rx.filter_map(|kill| {
75                if kill.is_none() {
76                    future::ready(None)
77                } else {
78                    future::ready(Some(Err(
79                        Error::new(ErrorKind::Other, "connection killed").into()
80                    )))
81                }
82            }),
83        );
84
85        (Box::new(tx), Box::new(rx))
86    }
87}