1use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
2use futures::{SinkExt as _, StreamExt as _};
3
4pub struct Conn {
5 pub(crate) tx:
6 Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
7 pub(crate) rx: Box<
8 dyn 'static
9 + Send
10 + Unpin
11 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
12 >,
13}
14
15impl Conn {
16 pub fn new<S>(stream: S) -> Self
17 where
18 S: 'static
19 + Send
20 + Unpin
21 + futures::Sink<WebSocketMessage, Error = WebSocketError>
22 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
23 {
24 let (tx, rx) = stream.split();
25 Self {
26 tx: Box::new(tx),
27 rx: Box::new(rx),
28 }
29 }
30
31 pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
32 self.tx.send(message).await
33 }
34
35 #[cfg(any(test, feature = "test-support"))]
36 pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
37 let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
38 postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
39
40 let (a_tx, a_rx) = Self::channel(kill_rx.clone());
41 let (b_tx, b_rx) = Self::channel(kill_rx);
42 (
43 Self { tx: a_tx, rx: b_rx },
44 Self { tx: b_tx, rx: a_rx },
45 kill_tx,
46 )
47 }
48
49 #[cfg(any(test, feature = "test-support"))]
50 fn channel(
51 kill_rx: postage::watch::Receiver<Option<()>>,
52 ) -> (
53 Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
54 Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
55 ) {
56 use futures::{future, stream, SinkExt as _, StreamExt as _};
57 use std::io::{Error, ErrorKind};
58
59 let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
60 let tx = tx
61 .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
62 .with({
63 let kill_rx = kill_rx.clone();
64 move |msg| {
65 if kill_rx.borrow().is_none() {
66 future::ready(Ok(msg))
67 } else {
68 future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
69 }
70 }
71 });
72 let rx = stream::select(
73 rx.map(Ok),
74 kill_rx.filter_map(|kill| {
75 if kill.is_none() {
76 future::ready(None)
77 } else {
78 future::ready(Some(Err(
79 Error::new(ErrorKind::Other, "connection killed").into()
80 )))
81 }
82 }),
83 );
84
85 (Box::new(tx), Box::new(rx))
86 }
87}