@@ -1,5 +1,6 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{SinkExt as _, StreamExt as _};
+use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
+use std::{io, task::Poll};
pub struct Conn {
pub(crate) tx:
@@ -53,10 +54,10 @@ impl Conn {
Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
) {
- use futures::{future, stream, SinkExt as _, StreamExt as _};
- use std::io::{Error, ErrorKind};
+ use futures::{future, SinkExt as _};
+ use io::{Error, ErrorKind};
- let (tx, rx) = futures::channel::mpsc::unbounded::<WebSocketMessage>();
+ let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
let tx = tx
.sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
.with({
@@ -69,19 +70,32 @@ impl Conn {
}
}
});
- let rx = stream::select(
- rx.map(Ok),
- kill_rx.filter_map(|kill| {
- if kill.is_none() {
- future::ready(None)
- } else {
- future::ready(Some(Err(
- Error::new(ErrorKind::Other, "connection killed").into()
- )))
- }
- }),
- );
+ let rx = KillableReceiver { kill_rx, rx };
(Box::new(tx), Box::new(rx))
}
}
+
+struct KillableReceiver {
+ rx: mpsc::UnboundedReceiver<WebSocketMessage>,
+ kill_rx: postage::watch::Receiver<Option<()>>,
+}
+
+impl Stream for KillableReceiver {
+ type Item = Result<WebSocketMessage, WebSocketError>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
+ Poll::Ready(Some(Err(io::Error::new(
+ io::ErrorKind::Other,
+ "connection killed",
+ )
+ .into())))
+ } else {
+ self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+ }
+ }
+}