1use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
2use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
3use std::{io, task::Poll};
4
5pub struct Connection {
6 pub(crate) tx:
7 Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
8 pub(crate) rx: Box<
9 dyn 'static
10 + Send
11 + Unpin
12 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
13 >,
14}
15
16impl Connection {
17 pub fn new<S>(stream: S) -> Self
18 where
19 S: 'static
20 + Send
21 + Unpin
22 + futures::Sink<WebSocketMessage, Error = WebSocketError>
23 + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
24 {
25 let (tx, rx) = stream.split();
26 Self {
27 tx: Box::new(tx),
28 rx: Box::new(rx),
29 }
30 }
31
32 pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
33 self.tx.send(message).await
34 }
35
36 #[cfg(any(test, feature = "test-support"))]
37 pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
38 let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
39 postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
40
41 let (a_tx, a_rx) = Self::channel(kill_rx.clone());
42 let (b_tx, b_rx) = Self::channel(kill_rx);
43 (
44 Self { tx: a_tx, rx: b_rx },
45 Self { tx: b_tx, rx: a_rx },
46 kill_tx,
47 )
48 }
49
50 #[cfg(any(test, feature = "test-support"))]
51 fn channel(
52 kill_rx: postage::watch::Receiver<Option<()>>,
53 ) -> (
54 Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
55 Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
56 ) {
57 use futures::{future, SinkExt as _};
58 use io::{Error, ErrorKind};
59
60 let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
61 let tx = tx
62 .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
63 .with({
64 let kill_rx = kill_rx.clone();
65 move |msg| {
66 if kill_rx.borrow().is_none() {
67 future::ready(Ok(msg))
68 } else {
69 future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
70 }
71 }
72 });
73 let rx = KillableReceiver { kill_rx, rx };
74
75 (Box::new(tx), Box::new(rx))
76 }
77}
78
79struct KillableReceiver {
80 rx: mpsc::UnboundedReceiver<WebSocketMessage>,
81 kill_rx: postage::watch::Receiver<Option<()>>,
82}
83
84impl Stream for KillableReceiver {
85 type Item = Result<WebSocketMessage, WebSocketError>;
86
87 fn poll_next(
88 mut self: std::pin::Pin<&mut Self>,
89 cx: &mut std::task::Context<'_>,
90 ) -> Poll<Option<Self::Item>> {
91 if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
92 Poll::Ready(Some(Err(io::Error::new(
93 io::ErrorKind::Other,
94 "connection killed",
95 )
96 .into())))
97 } else {
98 self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
99 }
100 }
101}