1use async_tungstenite::tungstenite::Message as WebSocketMessage;
2use futures::{SinkExt as _, StreamExt as _};
3
4pub struct Connection {
5 pub(crate) tx:
6 Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
7 pub(crate) rx:
8 Box<dyn 'static + Send + Unpin + futures::Stream<Item = anyhow::Result<WebSocketMessage>>>,
9}
10
11impl Connection {
12 pub fn new<S>(stream: S) -> Self
13 where
14 S: 'static
15 + Send
16 + Unpin
17 + futures::Sink<WebSocketMessage, Error = anyhow::Error>
18 + futures::Stream<Item = anyhow::Result<WebSocketMessage>>,
19 {
20 let (tx, rx) = stream.split();
21 Self {
22 tx: Box::new(tx),
23 rx: Box::new(rx),
24 }
25 }
26
27 pub async fn send(&mut self, message: WebSocketMessage) -> anyhow::Result<()> {
28 self.tx.send(message).await
29 }
30
31 #[cfg(any(test, feature = "test-support"))]
32 pub fn in_memory(
33 executor: gpui::BackgroundExecutor,
34 ) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
35 use std::sync::{
36 Arc,
37 atomic::{AtomicBool, Ordering::SeqCst},
38 };
39
40 let killed = Arc::new(AtomicBool::new(false));
41 let (a_tx, a_rx) = channel(killed.clone(), executor.clone());
42 let (b_tx, b_rx) = channel(killed.clone(), executor);
43 return (
44 Self { tx: a_tx, rx: b_rx },
45 Self { tx: b_tx, rx: a_rx },
46 killed,
47 );
48
49 #[allow(clippy::type_complexity)]
50 fn channel(
51 killed: Arc<AtomicBool>,
52 executor: gpui::BackgroundExecutor,
53 ) -> (
54 Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
55 Box<dyn Send + Unpin + futures::Stream<Item = anyhow::Result<WebSocketMessage>>>,
56 ) {
57 use anyhow::anyhow;
58 use futures::channel::mpsc;
59 use std::io::Error;
60
61 let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
62
63 let tx = tx.sink_map_err(|error| anyhow!(error)).with({
64 let killed = killed.clone();
65 let executor = executor.clone();
66 move |msg| {
67 let killed = killed.clone();
68 let executor = executor.clone();
69 Box::pin(async move {
70 executor.simulate_random_delay().await;
71
72 // Writes to a half-open TCP connection will error.
73 if killed.load(SeqCst) {
74 std::io::Result::Err(Error::other("connection lost"))?;
75 }
76
77 Ok(msg)
78 })
79 }
80 });
81
82 let rx = rx.then({
83 move |msg| {
84 let killed = killed.clone();
85 let executor = executor.clone();
86 Box::pin(async move {
87 executor.simulate_random_delay().await;
88
89 // Reads from a half-open TCP connection will hang.
90 if killed.load(SeqCst) {
91 futures::future::pending::<()>().await;
92 }
93
94 Ok(msg)
95 })
96 }
97 });
98
99 (Box::new(tx), Box::new(rx))
100 }
101 }
102}