1use std::pin::Pin;
2use std::time::Duration;
3
4use anyhow::Result;
5use cloud_api_types::websocket_protocol::MessageToClient;
6use futures::channel::mpsc::unbounded;
7use futures::stream::{SplitSink, SplitStream};
8use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _, TryStreamExt as _, pin_mut};
9use gpui::{App, BackgroundExecutor, Task};
10use yawc::WebSocket;
11use yawc::frame::{FrameView, OpCode};
12
/// Delay used for the keepalive timer in [`Connection::spawn`]: each time the
/// timer fires a ping frame is sent, and the timer is re-armed with this same
/// delay.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
14
15pub type MessageStream = Pin<Box<dyn Stream<Item = Result<MessageToClient>>>>;
16
17pub struct Connection {
18 tx: SplitSink<WebSocket, FrameView>,
19 rx: SplitStream<WebSocket>,
20}
21
22impl Connection {
23 pub fn new(ws: WebSocket) -> Self {
24 let (tx, rx) = ws.split();
25
26 Self { tx, rx }
27 }
28
29 pub fn spawn(self, cx: &App) -> (MessageStream, Task<()>) {
30 let (mut tx, rx) = (self.tx, self.rx);
31
32 let (message_tx, message_rx) = unbounded();
33
34 let handle_io = |executor: BackgroundExecutor| async move {
35 // Send messages on this frequency so the connection isn't closed.
36 let keepalive_timer = executor.timer(KEEPALIVE_INTERVAL).fuse();
37 futures::pin_mut!(keepalive_timer);
38
39 let rx = rx.fuse();
40 pin_mut!(rx);
41
42 loop {
43 futures::select_biased! {
44 _ = keepalive_timer => {
45 let _ = tx.send(FrameView::ping(Vec::new())).await;
46
47 keepalive_timer.set(executor.timer(KEEPALIVE_INTERVAL).fuse());
48 }
49 frame = rx.next() => {
50 let Some(frame) = frame else {
51 break;
52 };
53
54 match frame.opcode {
55 OpCode::Binary => {
56 let message_result = MessageToClient::deserialize(&frame.payload);
57 message_tx.unbounded_send(message_result).ok();
58 }
59 OpCode::Close => {
60 break;
61 }
62 _ => {}
63 }
64 }
65 }
66 }
67 };
68
69 let task = cx.spawn(async move |cx| handle_io(cx.background_executor().clone()).await);
70
71 (message_rx.into_stream().boxed(), task)
72 }
73}