1use async_tungstenite::tungstenite::Message as WebSocketMessage;
2use futures::{SinkExt as _, StreamExt as _};
3
4pub struct Connection {
5 pub(crate) tx:
6 Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
7 pub(crate) rx: Box<
8 dyn 'static
9 + Send
10 + Unpin
11 + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
12 >,
13}
14
15impl Connection {
16 pub fn new<S>(stream: S) -> Self
17 where
18 S: 'static
19 + Send
20 + Unpin
21 + futures::Sink<WebSocketMessage, Error = anyhow::Error>
22 + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>,
23 {
24 let (tx, rx) = stream.split();
25 Self {
26 tx: Box::new(tx),
27 rx: Box::new(rx),
28 }
29 }
30
31 pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), anyhow::Error> {
32 self.tx.send(message).await
33 }
34
35 #[cfg(any(test, feature = "test-support"))]
36 pub fn in_memory(
37 executor: std::sync::Arc<gpui::executor::Background>,
38 ) -> (Self, Self, std::sync::Arc<std::sync::atomic::AtomicBool>) {
39 use std::sync::{
40 atomic::{AtomicBool, Ordering::SeqCst},
41 Arc,
42 };
43
44 let killed = Arc::new(AtomicBool::new(false));
45 let (a_tx, a_rx) = channel(killed.clone(), executor.clone());
46 let (b_tx, b_rx) = channel(killed.clone(), executor);
47 return (
48 Self { tx: a_tx, rx: b_rx },
49 Self { tx: b_tx, rx: a_rx },
50 killed,
51 );
52
53 #[allow(clippy::type_complexity)]
54 fn channel(
55 killed: Arc<AtomicBool>,
56 executor: Arc<gpui::executor::Background>,
57 ) -> (
58 Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = anyhow::Error>>,
59 Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, anyhow::Error>>>,
60 ) {
61 use anyhow::anyhow;
62 use futures::channel::mpsc;
63 use std::io::{Error, ErrorKind};
64
65 let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
66
67 let tx = tx.sink_map_err(|error| anyhow!(error)).with({
68 let killed = killed.clone();
69 let executor = Arc::downgrade(&executor);
70 move |msg| {
71 let killed = killed.clone();
72 let executor = executor.clone();
73 Box::pin(async move {
74 if let Some(executor) = executor.upgrade() {
75 executor.simulate_random_delay().await;
76 }
77
78 // Writes to a half-open TCP connection will error.
79 if killed.load(SeqCst) {
80 std::io::Result::Err(Error::new(ErrorKind::Other, "connection lost"))?;
81 }
82
83 Ok(msg)
84 })
85 }
86 });
87
88 let rx = rx.then({
89 let killed = killed;
90 let executor = Arc::downgrade(&executor);
91 move |msg| {
92 let killed = killed.clone();
93 let executor = executor.clone();
94 Box::pin(async move {
95 if let Some(executor) = executor.upgrade() {
96 executor.simulate_random_delay().await;
97 }
98
99 // Reads from a half-open TCP connection will hang.
100 if killed.load(SeqCst) {
101 futures::future::pending::<()>().await;
102 }
103
104 Ok(msg)
105 })
106 }
107 });
108
109 (Box::new(tx), Box::new(rx))
110 }
111 }
112}