websocket.rs

 1use std::pin::Pin;
 2use std::time::Duration;
 3
 4use anyhow::Result;
 5use cloud_api_types::websocket_protocol::MessageToClient;
 6use futures::channel::mpsc::unbounded;
 7use futures::stream::{SplitSink, SplitStream};
 8use futures::{FutureExt as _, SinkExt as _, Stream, StreamExt as _, TryStreamExt as _, pin_mut};
 9use gpui::{App, BackgroundExecutor, Task};
10use yawc::WebSocket;
11use yawc::frame::{FrameView, OpCode};
12
/// How often the I/O task writes a ping frame so the connection isn't closed.
const KEEPALIVE_INTERVAL: Duration = Duration::from_millis(1_000);
14
15pub type MessageStream = Pin<Box<dyn Stream<Item = Result<MessageToClient>>>>;
16
17pub struct Connection {
18    tx: SplitSink<WebSocket, FrameView>,
19    rx: SplitStream<WebSocket>,
20}
21
22impl Connection {
23    pub fn new(ws: WebSocket) -> Self {
24        let (tx, rx) = ws.split();
25
26        Self { tx, rx }
27    }
28
29    pub fn spawn(self, cx: &App) -> (MessageStream, Task<()>) {
30        let (mut tx, rx) = (self.tx, self.rx);
31
32        let (message_tx, message_rx) = unbounded();
33
34        let handle_io = |executor: BackgroundExecutor| async move {
35            // Send messages on this frequency so the connection isn't closed.
36            let keepalive_timer = executor.timer(KEEPALIVE_INTERVAL).fuse();
37            futures::pin_mut!(keepalive_timer);
38
39            let rx = rx.fuse();
40            pin_mut!(rx);
41
42            loop {
43                futures::select_biased! {
44                    _ = keepalive_timer => {
45                        let _ = tx.send(FrameView::ping(Vec::new())).await;
46
47                        keepalive_timer.set(executor.timer(KEEPALIVE_INTERVAL).fuse());
48                    }
49                    frame = rx.next() => {
50                        let Some(frame) = frame else {
51                            break;
52                        };
53
54                        match frame.opcode {
55                            OpCode::Binary => {
56                                let message_result = MessageToClient::deserialize(&frame.payload);
57                                message_tx.unbounded_send(message_result).ok();
58                            }
59                            OpCode::Close => {
60                                break;
61                            }
62                            _ => {}
63                        }
64                    }
65                }
66            }
67        };
68
69        let task = cx.spawn(async move |cx| handle_io(cx.background_executor().clone()).await);
70
71        (message_rx.into_stream().boxed(), task)
72    }
73}