Merge commit '680b86b17c63b67f768bc5da5f34e5ccf056a0ce' into main

Max Brunsfeld created

Change summary

gpui/src/executor.rs     |  96 +++++++--
server/src/rpc.rs        | 315 ++++++++++++++++++++++++++++---
zed/src/channel.rs       | 175 ++++++++--------
zed/src/chat_panel.rs    |  32 +-
zed/src/editor/buffer.rs |   9 
zed/src/rpc.rs           | 418 +++++++++++++++++++++++++++++------------
zed/src/test.rs          | 129 ++++++++++++
zed/src/worktree.rs      |  21 +
zrpc/proto/zed.proto     |  14 
zrpc/src/conn.rs         | 101 ++++++++++
zrpc/src/lib.rs          |   5 
zrpc/src/peer.rs         |  65 ++---
zrpc/src/proto.rs        |  32 --
zrpc/src/test.rs         |  64 ------
14 files changed, 1,041 insertions(+), 435 deletions(-)

Detailed changes

gpui/src/executor.rs 🔗

@@ -3,8 +3,9 @@ use async_task::Runnable;
 pub use async_task::Task;
 use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
 use parking_lot::Mutex;
+use postage::{barrier, prelude::Stream as _};
 use rand::prelude::*;
-use smol::{channel, prelude::*, Executor};
+use smol::{channel, prelude::*, Executor, Timer};
 use std::{
     fmt::{self, Debug},
     marker::PhantomData,
@@ -18,7 +19,7 @@ use std::{
     },
     task::{Context, Poll},
     thread,
-    time::Duration,
+    time::{Duration, Instant},
 };
 use waker_fn::waker_fn;
 
@@ -49,6 +50,8 @@ struct DeterministicState {
     spawned_from_foreground: Vec<(Runnable, Backtrace)>,
     forbid_parking: bool,
     block_on_ticks: RangeInclusive<usize>,
+    now: Instant,
+    pending_timers: Vec<(Instant, barrier::Sender)>,
 }
 
 pub struct Deterministic {
@@ -67,6 +70,8 @@ impl Deterministic {
                 spawned_from_foreground: Default::default(),
                 forbid_parking: false,
                 block_on_ticks: 0..=1000,
+                now: Instant::now(),
+                pending_timers: Default::default(),
             })),
             parker: Default::default(),
         }
@@ -119,17 +124,39 @@ impl Deterministic {
         T: 'static,
         F: Future<Output = T> + 'static,
     {
+        let woken = Arc::new(AtomicBool::new(false));
+        let mut future = Box::pin(future);
+        loop {
+            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
+                return result;
+            }
+
+            if !woken.load(SeqCst) && self.state.lock().forbid_parking {
+                panic!("deterministic executor parked after a call to forbid_parking");
+            }
+
+            woken.store(false, SeqCst);
+            self.parker.lock().park();
+        }
+    }
+
+    fn run_until_parked(&self) {
+        let woken = Arc::new(AtomicBool::new(false));
+        let future = std::future::pending::<()>();
         smol::pin!(future);
+        self.run_internal(woken, future);
+    }
 
+    pub fn run_internal<F, T>(&self, woken: Arc<AtomicBool>, mut future: F) -> Option<T>
+    where
+        T: 'static,
+        F: Future<Output = T> + Unpin,
+    {
         let unparker = self.parker.lock().unparker();
-        let woken = Arc::new(AtomicBool::new(false));
-        let waker = {
-            let woken = woken.clone();
-            waker_fn(move || {
-                woken.store(true, SeqCst);
-                unparker.unpark();
-            })
-        };
+        let waker = waker_fn(move || {
+            woken.store(true, SeqCst);
+            unparker.unpark();
+        });
 
         let mut cx = Context::from_waker(&waker);
         let mut trace = Trace::default();
@@ -163,23 +190,17 @@ impl Deterministic {
                 runnable.run();
             } else {
                 drop(state);
-                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
-                    return result;
+                if let Poll::Ready(result) = future.poll(&mut cx) {
+                    return Some(result);
                 }
+
                 let state = self.state.lock();
                 if state.scheduled_from_foreground.is_empty()
                     && state.scheduled_from_background.is_empty()
                     && state.spawned_from_foreground.is_empty()
                 {
-                    if state.forbid_parking && !woken.load(SeqCst) {
-                        panic!("deterministic executor parked after a call to forbid_parking");
-                    }
-                    drop(state);
-                    woken.store(false, SeqCst);
-                    self.parker.lock().park();
+                    return None;
                 }
-
-                continue;
             }
         }
     }
@@ -407,6 +428,41 @@ impl Foreground {
         }
     }
 
+    pub async fn timer(&self, duration: Duration) {
+        match self {
+            Self::Deterministic(executor) => {
+                let (tx, mut rx) = barrier::channel();
+                {
+                    let mut state = executor.state.lock();
+                    let wakeup_at = state.now + duration;
+                    state.pending_timers.push((wakeup_at, tx));
+                }
+                rx.recv().await;
+            }
+            _ => {
+                Timer::after(duration).await;
+            }
+        }
+    }
+
+    pub fn advance_clock(&self, duration: Duration) {
+        match self {
+            Self::Deterministic(executor) => {
+                executor.run_until_parked();
+
+                let mut state = executor.state.lock();
+                state.now += duration;
+                let now = state.now;
+                let mut pending_timers = mem::take(&mut state.pending_timers);
+                drop(state);
+
+                pending_timers.retain(|(wakeup, _)| *wakeup > now);
+                executor.state.lock().pending_timers.extend(pending_timers);
+            }
+            _ => panic!("this method can only be called on a deterministic executor"),
+        }
+    }
+
     pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
         match self {
             Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,

server/src/rpc.rs 🔗

@@ -5,10 +5,7 @@ use super::{
 };
 use anyhow::anyhow;
 use async_std::{sync::RwLock, task};
-use async_tungstenite::{
-    tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
-    WebSocketStream,
-};
+use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use futures::{future::BoxFuture, FutureExt};
 use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
 use sha1::{Digest as _, Sha1};
@@ -30,7 +27,7 @@ use time::OffsetDateTime;
 use zrpc::{
     auth::random_token,
     proto::{self, AnyTypedEnvelope, EnvelopedMessage},
-    ConnectionId, Peer, TypedEnvelope,
+    Conn, ConnectionId, Peer, TypedEnvelope,
 };
 
 type ReplicaId = u16;
@@ -95,6 +92,7 @@ impl Server {
         };
 
         server
+            .add_handler(Server::ping)
             .add_handler(Server::share_worktree)
             .add_handler(Server::join_worktree)
             .add_handler(Server::update_worktree)
@@ -133,19 +131,12 @@ impl Server {
         self
     }
 
-    pub fn handle_connection<Conn>(
+    pub fn handle_connection(
         self: &Arc<Self>,
         connection: Conn,
         addr: String,
         user_id: UserId,
-    ) -> impl Future<Output = ()>
-    where
-        Conn: 'static
-            + futures::Sink<WebSocketMessage, Error = WebSocketError>
-            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
-            + Send
-            + Unpin,
-    {
+    ) -> impl Future<Output = ()> {
         let this = self.clone();
         async move {
             let (connection_id, handle_io, mut incoming_rx) =
@@ -254,6 +245,11 @@ impl Server {
         worktree_ids
     }
 
+    async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
+        self.peer.respond(request.receipt(), proto::Ack {}).await?;
+        Ok(())
+    }
+
     async fn share_worktree(
         self: Arc<Server>,
         mut request: TypedEnvelope<proto::ShareWorktree>,
@@ -503,7 +499,9 @@ impl Server {
         request: TypedEnvelope<proto::UpdateBuffer>,
     ) -> tide::Result<()> {
         self.broadcast_in_worktree(request.payload.worktree_id, &request)
-            .await
+            .await?;
+        self.peer.respond(request.receipt(), proto::Ack {}).await?;
+        Ok(())
     }
 
     async fn buffer_saved(
@@ -974,8 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
             let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
             task::spawn(async move {
                 if let Some(stream) = upgrade_receiver.await {
-                    let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
-                    server.handle_connection(stream, addr, user_id).await;
+                    server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
                 }
             });
 
@@ -1009,17 +1006,25 @@ mod tests {
     };
     use async_std::{sync::RwLockReadGuard, task};
     use gpui::TestAppContext;
-    use postage::mpsc;
+    use parking_lot::Mutex;
+    use postage::{mpsc, watch};
     use serde_json::json;
     use sqlx::types::time::OffsetDateTime;
-    use std::{path::Path, sync::Arc, time::Duration};
+    use std::{
+        path::Path,
+        sync::{
+            atomic::{AtomicBool, Ordering::SeqCst},
+            Arc,
+        },
+        time::Duration,
+    };
     use zed::{
         channel::{Channel, ChannelDetails, ChannelList},
         editor::{Editor, Insert},
         fs::{FakeFs, Fs as _},
         language::LanguageRegistry,
-        rpc::Client,
-        settings, test,
+        rpc::{self, Client},
+        settings,
         user::UserStore,
         worktree::Worktree,
     };
@@ -1469,7 +1474,7 @@ mod tests {
             .await;
 
         // Drop client B's connection and ensure client A observes client B leaving the worktree.
-        client_b.disconnect().await.unwrap();
+        client_b.disconnect(&cx_b.to_async()).await.unwrap();
         worktree_a
             .condition(&cx_a, |tree, _| tree.peers().len() == 0)
             .await;
@@ -1675,11 +1680,206 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
+        // Connect to a server as 2 clients.
+        let mut server = TestServer::start().await;
+        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+        let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+        let mut status_b = client_b.status();
+
+        // Create an org that includes these 2 users.
+        let db = &server.app_state.db;
+        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+        db.add_org_member(org_id, user_id_a, false).await.unwrap();
+        db.add_org_member(org_id, user_id_b, false).await.unwrap();
+
+        // Create a channel that includes all the users.
+        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+        db.add_channel_member(channel_id, user_id_a, false)
+            .await
+            .unwrap();
+        db.add_channel_member(channel_id, user_id_b, false)
+            .await
+            .unwrap();
+        db.create_channel_message(
+            channel_id,
+            user_id_b,
+            "hello A, it's B.",
+            OffsetDateTime::now_utc(),
+        )
+        .await
+        .unwrap();
+
+        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+        channels_a
+            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+            .await;
+
+        channels_a.read_with(&cx_a, |list, _| {
+            assert_eq!(
+                list.available_channels().unwrap(),
+                &[ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+            this.get_channel(channel_id.to_proto(), cx).unwrap()
+        });
+        channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
+        channel_a
+            .condition(&cx_a, |channel, _| {
+                channel_messages(channel)
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+            })
+            .await;
+
+        let user_store_b = Arc::new(UserStore::new(client_b.clone()));
+        let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
+        channels_b
+            .condition(&mut cx_b, |list, _| list.available_channels().is_some())
+            .await;
+        channels_b.read_with(&cx_b, |list, _| {
+            assert_eq!(
+                list.available_channels().unwrap(),
+                &[ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+
+        let channel_b = channels_b.update(&mut cx_b, |this, cx| {
+            this.get_channel(channel_id.to_proto(), cx).unwrap()
+        });
+        channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [("user_b".to_string(), "hello A, it's B.".to_string())]
+            })
+            .await;
+
+        // Disconnect client B, ensuring we can still access its cached channel data.
+        server.forbid_connections();
+        server.disconnect_client(user_id_b);
+        while !matches!(
+            status_b.recv().await,
+            Some(rpc::Status::ReconnectionError { .. })
+        ) {}
+
+        channels_b.read_with(&cx_b, |channels, _| {
+            assert_eq!(
+                channels.available_channels().unwrap(),
+                [ChannelDetails {
+                    id: channel_id.to_proto(),
+                    name: "test-channel".to_string()
+                }]
+            )
+        });
+        channel_b.read_with(&cx_b, |channel, _| {
+            assert_eq!(
+                channel_messages(channel),
+                [("user_b".to_string(), "hello A, it's B.".to_string())]
+            )
+        });
+
+        // Send a message from client A while B is disconnected.
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
+                channel
+                    .send_message("oh, hi B.".to_string(), cx)
+                    .unwrap()
+                    .detach();
+                let task = channel.send_message("sup".to_string(), cx).unwrap();
+                assert_eq!(
+                    channel
+                        .pending_messages()
+                        .iter()
+                        .map(|m| &m.body)
+                        .collect::<Vec<_>>(),
+                    &["oh, hi B.", "sup"]
+                );
+                task
+            })
+            .await
+            .unwrap();
+
+        // Give client B a chance to reconnect.
+        server.allow_connections();
+        cx_b.foreground().advance_clock(Duration::from_secs(10));
+
+        // Verify that B sees the new messages upon reconnection.
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        ("user_b".to_string(), "hello A, it's B.".to_string()),
+                        ("user_a".to_string(), "oh, hi B.".to_string()),
+                        ("user_a".to_string(), "sup".to_string()),
+                    ]
+            })
+            .await;
+
+        // Ensure client A and B can communicate normally after reconnection.
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
+                channel.send_message("you online?".to_string(), cx).unwrap()
+            })
+            .await
+            .unwrap();
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        ("user_b".to_string(), "hello A, it's B.".to_string()),
+                        ("user_a".to_string(), "oh, hi B.".to_string()),
+                        ("user_a".to_string(), "sup".to_string()),
+                        ("user_a".to_string(), "you online?".to_string()),
+                    ]
+            })
+            .await;
+
+        channel_b
+            .update(&mut cx_b, |channel, cx| {
+                channel.send_message("yep".to_string(), cx).unwrap()
+            })
+            .await
+            .unwrap();
+        channel_a
+            .condition(&cx_a, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        ("user_b".to_string(), "hello A, it's B.".to_string()),
+                        ("user_a".to_string(), "oh, hi B.".to_string()),
+                        ("user_a".to_string(), "sup".to_string()),
+                        ("user_a".to_string(), "you online?".to_string()),
+                        ("user_b".to_string(), "yep".to_string()),
+                    ]
+            })
+            .await;
+
+        fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
+            channel
+                .messages()
+                .cursor::<(), ()>()
+                .map(|m| (m.sender.github_login.clone(), m.body.clone()))
+                .collect()
+        }
+    }
+
     struct TestServer {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,
         server: Arc<Server>,
         notifications: mpsc::Receiver<()>,
+        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+        forbid_connections: Arc<AtomicBool>,
         _test_db: TestDb,
     }
 
@@ -1695,6 +1895,8 @@ mod tests {
                 app_state,
                 server,
                 notifications: notifications.1,
+                connection_killers: Default::default(),
+                forbid_connections: Default::default(),
                 _test_db: test_db,
             }
         }
@@ -1704,20 +1906,67 @@ mod tests {
             cx: &mut TestAppContext,
             name: &str,
         ) -> (UserId, Arc<Client>) {
-            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
-            let client = Client::new();
-            let (client_conn, server_conn) = test::Channel::bidirectional();
-            cx.background()
-                .spawn(
-                    self.server
-                        .handle_connection(server_conn, name.to_string(), user_id),
-                )
-                .detach();
+            let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
+            let client_name = name.to_string();
+            let mut client = Client::new();
+            let server = self.server.clone();
+            let connection_killers = self.connection_killers.clone();
+            let forbid_connections = self.forbid_connections.clone();
+            Arc::get_mut(&mut client)
+                .unwrap()
+                .set_login_and_connect_callbacks(
+                    move |cx| {
+                        cx.spawn(|_| async move {
+                            let access_token = "the-token".to_string();
+                            Ok((client_user_id.0 as u64, access_token))
+                        })
+                    },
+                    move |user_id, access_token, cx| {
+                        assert_eq!(user_id, client_user_id.0 as u64);
+                        assert_eq!(access_token, "the-token");
+
+                        let server = server.clone();
+                        let connection_killers = connection_killers.clone();
+                        let forbid_connections = forbid_connections.clone();
+                        let client_name = client_name.clone();
+                        cx.spawn(move |cx| async move {
+                            if forbid_connections.load(SeqCst) {
+                                Err(anyhow!("server is forbidding connections"))
+                            } else {
+                                let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+                                connection_killers.lock().insert(client_user_id, kill_conn);
+                                cx.background()
+                                    .spawn(server.handle_connection(
+                                        server_conn,
+                                        client_name,
+                                        client_user_id,
+                                    ))
+                                    .detach();
+                                Ok(client_conn)
+                            }
+                        })
+                    },
+                );
+
             client
-                .add_connection(user_id.to_proto(), client_conn, &cx.to_async())
+                .authenticate_and_connect(&cx.to_async())
                 .await
                 .unwrap();
-            (user_id, client)
+            (client_user_id, client)
+        }
+
+        fn disconnect_client(&self, user_id: UserId) {
+            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
+                let _ = kill_conn.try_send(Some(()));
+            }
+        }
+
+        fn forbid_connections(&self) {
+            self.forbid_connections.store(true, SeqCst);
+        }
+
+        fn allow_connections(&self) {
+            self.forbid_connections.store(false, SeqCst);
         }
 
         async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

zed/src/channel.rs 🔗

@@ -11,6 +11,7 @@ use gpui::{
 use postage::prelude::Stream;
 use std::{
     collections::{HashMap, HashSet},
+    mem,
     ops::Range,
     sync::Arc,
 };
@@ -71,7 +72,7 @@ pub enum ChannelListEvent {}
 
 #[derive(Clone, Debug, PartialEq)]
 pub enum ChannelEvent {
-    MessagesAdded {
+    MessagesUpdated {
         old_range: Range<usize>,
         new_count: usize,
     },
@@ -87,36 +88,47 @@ impl ChannelList {
         rpc: Arc<rpc::Client>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let _task = cx.spawn(|this, mut cx| {
+        let _task = cx.spawn_weak(|this, mut cx| {
             let rpc = rpc.clone();
             async move {
-                let mut user_id = rpc.user_id();
-                loop {
-                    let available_channels = if user_id.recv().await.unwrap().is_some() {
-                        Some(
-                            rpc.request(proto::GetChannels {})
+                let mut status = rpc.status();
+                while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
+                    match status {
+                        rpc::Status::Connected { .. } => {
+                            let response = rpc
+                                .request(proto::GetChannels {})
                                 .await
-                                .context("failed to fetch available channels")?
-                                .channels
-                                .into_iter()
-                                .map(Into::into)
-                                .collect(),
-                        )
-                    } else {
-                        None
-                    };
-
-                    this.update(&mut cx, |this, cx| {
-                        if available_channels.is_none() {
-                            if this.available_channels.is_none() {
-                                return;
-                            }
-                            this.channels.clear();
+                                .context("failed to fetch available channels")?;
+                            this.update(&mut cx, |this, cx| {
+                                this.available_channels =
+                                    Some(response.channels.into_iter().map(Into::into).collect());
+
+                                let mut to_remove = Vec::new();
+                                for (channel_id, channel) in &this.channels {
+                                    if let Some(channel) = channel.upgrade(cx) {
+                                        channel.update(cx, |channel, cx| channel.rejoin(cx))
+                                    } else {
+                                        to_remove.push(*channel_id);
+                                    }
+                                }
+
+                                for channel_id in to_remove {
+                                    this.channels.remove(&channel_id);
+                                }
+                                cx.notify();
+                            });
                         }
-                        this.available_channels = available_channels;
-                        cx.notify();
-                    });
+                        rpc::Status::Disconnected { .. } => {
+                            this.update(&mut cx, |this, cx| {
+                                this.available_channels = None;
+                                this.channels.clear();
+                                cx.notify();
+                            });
+                        }
+                        _ => {}
+                    }
                 }
+                Ok(())
             }
             .log_err()
         });
@@ -285,6 +297,43 @@ impl Channel {
         false
     }
 
+    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
+        let user_store = self.user_store.clone();
+        let rpc = self.rpc.clone();
+        let channel_id = self.details.id;
+        cx.spawn(|channel, mut cx| {
+            async move {
+                let response = rpc.request(proto::JoinChannel { channel_id }).await?;
+                let messages = messages_from_proto(response.messages, &user_store).await?;
+                let loaded_all_messages = response.done;
+
+                channel.update(&mut cx, |channel, cx| {
+                    if let Some((first_new_message, last_old_message)) =
+                        messages.first().zip(channel.messages.last())
+                    {
+                        if first_new_message.id > last_old_message.id {
+                            let old_messages = mem::take(&mut channel.messages);
+                            cx.emit(ChannelEvent::MessagesUpdated {
+                                old_range: 0..old_messages.summary().count,
+                                new_count: 0,
+                            });
+                            channel.loaded_all_messages = loaded_all_messages;
+                        }
+                    }
+
+                    channel.insert_messages(messages, cx);
+                    if loaded_all_messages {
+                        channel.loaded_all_messages = loaded_all_messages;
+                    }
+                });
+
+                Ok(())
+            }
+            .log_err()
+        })
+        .detach();
+    }
+
     pub fn message_count(&self) -> usize {
         self.messages.summary().count
     }
@@ -350,7 +399,7 @@ impl Channel {
             drop(old_cursor);
             self.messages = new_messages;
 
-            cx.emit(ChannelEvent::MessagesAdded {
+            cx.emit(ChannelEvent::MessagesUpdated {
                 old_range: start_ix..end_ix,
                 new_count,
             });
@@ -446,22 +495,21 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test::FakeServer;
     use gpui::TestAppContext;
-    use postage::mpsc::Receiver;
-    use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
 
     #[gpui::test]
     async fn test_channel_messages(mut cx: TestAppContext) {
         let user_id = 5;
-        let client = Client::new();
-        let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
         let user_store = Arc::new(UserStore::new(client.clone()));
 
         let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
         channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
 
         // Get the available channels.
-        let get_channels = server.receive::<proto::GetChannels>().await;
+        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
         server
             .respond(
                 get_channels.receipt(),
@@ -492,7 +540,7 @@ mod tests {
             })
             .unwrap();
         channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
-        let join_channel = server.receive::<proto::JoinChannel>().await;
+        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
         server
             .respond(
                 join_channel.receipt(),
@@ -517,7 +565,7 @@ mod tests {
             .await;
 
         // Client requests all users for the received messages
-        let mut get_users = server.receive::<proto::GetUsers>().await;
+        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
         get_users.payload.user_ids.sort();
         assert_eq!(get_users.payload.user_ids, vec![5, 6]);
         server
@@ -542,7 +590,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 0..0,
                 new_count: 2,
             }
@@ -574,7 +622,7 @@ mod tests {
             .await;
 
         // Client requests user for message since they haven't seen them yet
-        let get_users = server.receive::<proto::GetUsers>().await;
+        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
         assert_eq!(get_users.payload.user_ids, vec![7]);
         server
             .respond(
@@ -591,7 +639,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 2..2,
                 new_count: 1,
             }
@@ -610,7 +658,7 @@ mod tests {
         channel.update(&mut cx, |channel, cx| {
             assert!(channel.load_more_messages(cx));
         });
-        let get_messages = server.receive::<proto::GetChannelMessages>().await;
+        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
         assert_eq!(get_messages.payload.channel_id, 5);
         assert_eq!(get_messages.payload.before_message_id, 10);
         server
@@ -638,7 +686,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 0..0,
                 new_count: 2,
             }
@@ -656,53 +704,4 @@ mod tests {
             );
         });
     }
-
-    struct FakeServer {
-        peer: Arc<Peer>,
-        incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
-        connection_id: ConnectionId,
-    }
-
-    impl FakeServer {
-        async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
-            let (client_conn, server_conn) = Channel::bidirectional();
-            let peer = Peer::new();
-            let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
-            cx.background().spawn(io).detach();
-
-            client
-                .add_connection(user_id, client_conn, &cx.to_async())
-                .await
-                .unwrap();
-
-            Self {
-                peer,
-                incoming,
-                connection_id,
-            }
-        }
-
-        async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
-            self.peer.send(self.connection_id, message).await.unwrap();
-        }
-
-        async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
-            *self
-                .incoming
-                .recv()
-                .await
-                .unwrap()
-                .into_any()
-                .downcast::<TypedEnvelope<M>>()
-                .unwrap()
-        }
-
-        async fn respond<T: proto::RequestMessage>(
-            &self,
-            receipt: Receipt<T>,
-            response: T::Response,
-        ) {
-            self.peer.respond(receipt, response).await.unwrap()
-        }
-    }
 }

zed/src/chat_panel.rs 🔗

@@ -3,7 +3,7 @@ use std::sync::Arc;
 use crate::{
     channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
     editor::Editor,
-    rpc::Client,
+    rpc::{self, Client},
     theme,
     util::{ResultExt, TryFutureExt},
     Settings,
@@ -14,10 +14,10 @@ use gpui::{
     keymap::Binding,
     platform::CursorStyle,
     views::{ItemType, Select, SelectStyle},
-    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
+    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
     ViewContext, ViewHandle,
 };
-use postage::watch;
+use postage::{prelude::Stream, watch};
 use time::{OffsetDateTime, UtcOffset};
 
 const MESSAGE_LOADING_THRESHOLD: usize = 50;
@@ -31,6 +31,7 @@ pub struct ChatPanel {
     channel_select: ViewHandle<Select>,
     settings: watch::Receiver<Settings>,
     local_timezone: UtcOffset,
+    _observe_status: Task<()>,
 }
 
 pub enum Event {}
@@ -98,6 +99,14 @@ impl ChatPanel {
                 cx.dispatch_action(LoadMoreMessages);
             }
         });
+        let _observe_status = cx.spawn(|this, mut cx| {
+            let mut status = rpc.status();
+            async move {
+                while let Some(_) = status.recv().await {
+                    this.update(&mut cx, |_, cx| cx.notify());
+                }
+            }
+        });
 
         let mut this = Self {
             rpc,
@@ -108,6 +117,7 @@ impl ChatPanel {
             channel_select,
             settings,
             local_timezone: cx.platform().local_timezone(),
+            _observe_status,
         };
 
         this.init_active_channel(cx);
@@ -153,6 +163,7 @@ impl ChatPanel {
         if let Some(active_channel) = active_channel {
             self.set_active_channel(active_channel, cx);
         } else {
+            self.message_list.reset(0);
             self.active_channel = None;
         }
 
@@ -183,7 +194,7 @@ impl ChatPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range,
                 new_count,
             } => {
@@ -357,10 +368,6 @@ impl ChatPanel {
             })
         }
     }
-
-    fn is_signed_in(&self) -> bool {
-        self.rpc.user_id().borrow().is_some()
-    }
 }
 
 impl Entity for ChatPanel {
@@ -374,10 +381,9 @@ impl View for ChatPanel {
 
     fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
         let theme = &self.settings.borrow().theme;
-        let element = if self.is_signed_in() {
-            self.render_channel()
-        } else {
-            self.render_sign_in_prompt(cx)
+        let element = match *self.rpc.status().borrow() {
+            rpc::Status::Connected { .. } => self.render_channel(),
+            _ => self.render_sign_in_prompt(cx),
         };
         ConstrainedBox::new(
             Container::new(element)
@@ -389,7 +395,7 @@ impl View for ChatPanel {
     }
 
     fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
-        if self.is_signed_in() {
+        if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
             cx.focus(&self.input_editor);
         }
     }

zed/src/editor/buffer.rs 🔗

@@ -2695,14 +2695,7 @@ impl<'a> Into<proto::operation::Edit> for &'a EditOperation {
 impl<'a> Into<proto::Anchor> for &'a Anchor {
     fn into(self) -> proto::Anchor {
         proto::Anchor {
-            version: self
-                .version
-                .iter()
-                .map(|entry| proto::VectorClockEntry {
-                    replica_id: entry.replica_id as u32,
-                    timestamp: entry.value,
-                })
-                .collect(),
+            version: (&self.version).into(),
             offset: self.offset as u64,
             bias: match self.bias {
                 Bias::Left => proto::anchor::Bias::Left as i32,

zed/src/rpc.rs 🔗

@@ -1,24 +1,24 @@
 use crate::util::ResultExt;
 use anyhow::{anyhow, Context, Result};
 use async_tungstenite::tungstenite::http::Request;
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use gpui::{AsyncAppContext, Entity, ModelContext, Task};
 use lazy_static::lazy_static;
 use parking_lot::RwLock;
-use postage::prelude::Stream;
-use postage::sink::Sink;
-use postage::watch;
-use std::any::TypeId;
-use std::collections::HashMap;
-use std::sync::Weak;
-use std::time::{Duration, Instant};
-use std::{convert::TryFrom, future::Future, sync::Arc};
+use postage::{prelude::Stream, watch};
+use rand::prelude::*;
+use std::{
+    any::TypeId,
+    collections::HashMap,
+    convert::TryFrom,
+    future::Future,
+    sync::{Arc, Weak},
+    time::{Duration, Instant},
+};
 use surf::Url;
-use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
 pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
 use zrpc::{
-    proto::{EnvelopedMessage, RequestMessage},
-    Peer, Receipt,
+    proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
+    Conn, Peer, Receipt,
 };
 
 lazy_static! {
@@ -29,25 +29,55 @@ lazy_static! {
 pub struct Client {
     peer: Arc<Peer>,
     state: RwLock<ClientState>,
+    auth_callback: Option<
+        Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
+    >,
+    connect_callback: Option<
+        Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+    >,
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum Status {
+    Disconnected,
+    Authenticating,
+    Connecting {
+        user_id: u64,
+    },
+    ConnectionError,
+    Connected {
+        connection_id: ConnectionId,
+        user_id: u64,
+    },
+    ConnectionLost,
+    Reauthenticating,
+    Reconnecting {
+        user_id: u64,
+    },
+    ReconnectionError {
+        next_reconnection: Instant,
+    },
 }
 
 struct ClientState {
-    connection_id: Option<ConnectionId>,
-    user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
+    status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
         (TypeId, u64),
         Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
     >,
+    _maintain_connection: Option<Task<()>>,
+    heartbeat_interval: Duration,
 }
 
 impl Default for ClientState {
     fn default() -> Self {
         Self {
-            connection_id: Default::default(),
-            user_id: watch::channel(),
+            status: watch::channel_with(Status::Disconnected),
             entity_id_extractors: Default::default(),
             model_handlers: Default::default(),
+            _maintain_connection: None,
+            heartbeat_interval: Duration::from_secs(5),
         }
     }
 }
@@ -77,11 +107,71 @@ impl Client {
         Arc::new(Self {
             peer: Peer::new(),
             state: Default::default(),
+            auth_callback: None,
+            connect_callback: None,
         })
     }
 
-    pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
-        self.state.read().user_id.1.clone()
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn set_login_and_connect_callbacks<Login, Connect>(
+        &mut self,
+        login: Login,
+        connect: Connect,
+    ) where
+        Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
+        Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+    {
+        self.auth_callback = Some(Box::new(login));
+        self.connect_callback = Some(Box::new(connect));
+    }
+
+    pub fn status(&self) -> watch::Receiver<Status> {
+        self.state.read().status.1.clone()
+    }
+
+    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
+        let mut state = self.state.write();
+        *state.status.0.borrow_mut() = status;
+
+        match status {
+            Status::Connected { .. } => {
+                let heartbeat_interval = state.heartbeat_interval;
+                let this = self.clone();
+                let foreground = cx.foreground();
+                state._maintain_connection = Some(cx.foreground().spawn(async move {
+                    loop {
+                        foreground.timer(heartbeat_interval).await;
+                        this.request(proto::Ping {}).await.unwrap();
+                    }
+                }));
+            }
+            Status::ConnectionLost => {
+                let this = self.clone();
+                let foreground = cx.foreground();
+                let heartbeat_interval = state.heartbeat_interval;
+                state._maintain_connection = Some(cx.spawn(|cx| async move {
+                    let mut rng = StdRng::from_entropy();
+                    let mut delay = Duration::from_millis(100);
+                    while let Err(error) = this.authenticate_and_connect(&cx).await {
+                        log::error!("failed to connect {}", error);
+                        this.set_status(
+                            Status::ReconnectionError {
+                                next_reconnection: Instant::now() + delay,
+                            },
+                            &cx,
+                        );
+                        foreground.timer(delay).await;
+                        delay = delay
+                            .mul_f32(rng.gen_range(1.0..=2.0))
+                            .min(heartbeat_interval);
+                    }
+                }));
+            }
+            Status::Disconnected => {
+                state._maintain_connection.take();
+            }
+            _ => {}
+        }
     }
 
     pub fn subscribe_from_model<T, M, F>(
@@ -141,56 +231,57 @@ impl Client {
         self: &Arc<Self>,
         cx: &AsyncAppContext,
     ) -> anyhow::Result<()> {
-        if self.state.read().connection_id.is_some() {
-            return Ok(());
-        }
-
-        let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
-        let user_id = user_id.parse::<u64>()?;
-        let request =
-            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+        let was_disconnected = match *self.status().borrow() {
+            Status::Disconnected => true,
+            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
+                false
+            }
+            Status::Connected { .. }
+            | Status::Connecting { .. }
+            | Status::Reconnecting { .. }
+            | Status::Authenticating
+            | Status::Reauthenticating => return Ok(()),
+        };
 
-        if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
-            let stream = smol::net::TcpStream::connect(host).await?;
-            let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
-            let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
-                .await
-                .context("websocket handshake")?;
-            self.add_connection(user_id, stream, cx).await?;
-        } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
-            let stream = smol::net::TcpStream::connect(host).await?;
-            let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
-            let (stream, _) = async_tungstenite::client_async(request, stream)
-                .await
-                .context("websocket handshake")?;
-            self.add_connection(user_id, stream, cx).await?;
+        if was_disconnected {
+            self.set_status(Status::Authenticating, cx);
         } else {
-            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
+            self.set_status(Status::Reauthenticating, cx)
+        }
+
+        let (user_id, access_token) = match self.authenticate(&cx).await {
+            Ok(result) => result,
+            Err(err) => {
+                self.set_status(Status::ConnectionError, cx);
+                return Err(err);
+            }
         };
 
-        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-        Ok(())
+        if was_disconnected {
+            self.set_status(Status::Connecting { user_id }, cx);
+        } else {
+            self.set_status(Status::Reconnecting { user_id }, cx);
+        }
+        match self.connect(user_id, &access_token, cx).await {
+            Ok(conn) => {
+                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+                self.set_connection(user_id, conn, cx).await;
+                Ok(())
+            }
+            Err(err) => {
+                self.set_status(Status::ConnectionError, cx);
+                Err(err)
+            }
+        }
     }
 
-    pub async fn add_connection<Conn>(
-        self: &Arc<Self>,
-        user_id: u64,
-        conn: Conn,
-        cx: &AsyncAppContext,
-    ) -> anyhow::Result<()>
-    where
-        Conn: 'static
-            + futures::Sink<WebSocketMessage, Error = WebSocketError>
-            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
-            + Unpin
-            + Send,
-    {
+    async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
         let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
-        {
-            let mut cx = cx.clone();
-            let this = self.clone();
-            cx.foreground()
-                .spawn(async move {
+        cx.foreground()
+            .spawn({
+                let mut cx = cx.clone();
+                let this = self.clone();
+                async move {
                     while let Some(message) = incoming.recv().await {
                         let mut state = this.state.write();
                         if let Some(extract_entity_id) =
@@ -215,27 +306,90 @@ impl Client {
                             log::info!("unhandled message {}", message.payload_type_name());
                         }
                     }
-                })
-                .detach();
-        }
-        cx.background()
+                }
+            })
+            .detach();
+
+        self.set_status(
+            Status::Connected {
+                connection_id,
+                user_id,
+            },
+            cx,
+        );
+
+        let handle_io = cx.background().spawn(handle_io);
+        let this = self.clone();
+        let cx = cx.clone();
+        cx.foreground()
             .spawn(async move {
-                if let Err(error) = handle_io.await {
-                    log::error!("connection error: {:?}", error);
+                match handle_io.await {
+                    Ok(()) => this.set_status(Status::Disconnected, &cx),
+                    Err(err) => {
+                        log::error!("connection error: {:?}", err);
+                        this.set_status(Status::ConnectionLost, &cx);
+                    }
                 }
             })
             .detach();
-        let mut state = self.state.write();
-        state.connection_id = Some(connection_id);
-        state.user_id.0.send(Some(user_id)).await?;
-        Ok(())
     }
 
-    pub fn login(
-        platform: Arc<dyn gpui::Platform>,
-        executor: &Arc<gpui::executor::Background>,
-    ) -> Task<Result<(String, String)>> {
-        let executor = executor.clone();
+    fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+        if let Some(callback) = self.auth_callback.as_ref() {
+            callback(cx)
+        } else {
+            self.authenticate_with_browser(cx)
+        }
+    }
+
+    fn connect(
+        self: &Arc<Self>,
+        user_id: u64,
+        access_token: &str,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<Conn>> {
+        if let Some(callback) = self.connect_callback.as_ref() {
+            callback(user_id, access_token, cx)
+        } else {
+            self.connect_with_websocket(user_id, access_token, cx)
+        }
+    }
+
+    fn connect_with_websocket(
+        self: &Arc<Self>,
+        user_id: u64,
+        access_token: &str,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<Conn>> {
+        let request =
+            Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+        cx.background().spawn(async move {
+            if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
+                let stream = smol::net::TcpStream::connect(host).await?;
+                let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
+                let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
+                    .await
+                    .context("websocket handshake")?;
+                Ok(Conn::new(stream))
+            } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
+                let stream = smol::net::TcpStream::connect(host).await?;
+                let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
+                let (stream, _) = async_tungstenite::client_async(request, stream)
+                    .await
+                    .context("websocket handshake")?;
+                Ok(Conn::new(stream))
+            } else {
+                Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
+            }
+        })
+    }
+
+    pub fn authenticate_with_browser(
+        self: &Arc<Self>,
+        cx: &AsyncAppContext,
+    ) -> Task<Result<(u64, String)>> {
+        let platform = cx.platform();
+        let executor = cx.background();
         executor.clone().spawn(async move {
             if let Some((user_id, access_token)) = platform
                 .read_credentials(&ZED_SERVER_URL)
@@ -243,7 +397,7 @@ impl Client {
                 .flatten()
             {
                 log::info!("already signed in. user_id: {}", user_id);
-                return Ok((user_id, String::from_utf8(access_token).unwrap()));
+                return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
             }
 
             // Generate a pair of asymmetric encryption keys. The public key will be used by the
@@ -309,21 +463,23 @@ impl Client {
             platform
                 .write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
                 .log_err();
-            Ok((user_id.to_string(), access_token))
+            Ok((user_id.parse()?, access_token))
         })
     }
 
-    pub async fn disconnect(&self) -> Result<()> {
+    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
         let conn_id = self.connection_id()?;
         self.peer.disconnect(conn_id).await;
+        self.set_status(Status::Disconnected, cx);
         Ok(())
     }
 
     fn connection_id(&self) -> Result<ConnectionId> {
-        self.state
-            .read()
-            .connection_id
-            .ok_or_else(|| anyhow!("not connected"))
+        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
+            Ok(connection_id)
+        } else {
+            Err(anyhow!("not connected"))
+        }
     }
 
     pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
@@ -343,35 +499,6 @@ impl Client {
     }
 }
 
-pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
-    type Output: 'a + Future<Output = anyhow::Result<()>>;
-
-    fn handle(
-        &self,
-        message: TypedEnvelope<M>,
-        rpc: &'a Client,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output;
-}
-
-impl<'a, M, F, Fut> MessageHandler<'a, M> for F
-where
-    M: proto::EnvelopedMessage,
-    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
-    Fut: 'a + Future<Output = anyhow::Result<()>>,
-{
-    type Output = Fut;
-
-    fn handle(
-        &self,
-        message: TypedEnvelope<M>,
-        rpc: &'a Client,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output {
-        (self)(message, rpc, cx)
-    }
-}
-
 const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
 
 pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@@ -396,13 +523,62 @@ const LOGIN_RESPONSE: &'static str = "
 </html>
 ";
 
-#[test]
-fn test_encode_and_decode_worktree_url() {
-    let url = encode_worktree_url(5, "deadbeef");
-    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
-    assert_eq!(
-        decode_worktree_url(&format!("\n {}\t", url)),
-        Some((5, "deadbeef".to_string()))
-    );
-    assert_eq!(decode_worktree_url("not://the-right-format"), None);
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::FakeServer;
+    use gpui::TestAppContext;
+
+    #[gpui::test(iterations = 10)]
+    async fn test_heartbeat(cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        let ping = server.receive::<proto::Ping>().await.unwrap();
+        server.respond(ping.receipt(), proto::Ack {}).await;
+
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        let ping = server.receive::<proto::Ping>().await.unwrap();
+        server.respond(ping.receipt(), proto::Ack {}).await;
+
+        client.disconnect(&cx.to_async()).await.unwrap();
+        assert!(server.receive::<proto::Ping>().await.is_err());
+    }
+
+    #[gpui::test(iterations = 10)]
+    async fn test_reconnection(cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+        let mut status = client.status();
+        assert!(matches!(
+            status.recv().await,
+            Some(Status::Connected { .. })
+        ));
+
+        server.forbid_connections();
+        server.disconnect().await;
+        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
+
+        server.allow_connections();
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+    }
+
+    #[test]
+    fn test_encode_and_decode_worktree_url() {
+        let url = encode_worktree_url(5, "deadbeef");
+        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
+        assert_eq!(
+            decode_worktree_url(&format!("\n {}\t", url)),
+            Some((5, "deadbeef".to_string()))
+        );
+        assert_eq!(decode_worktree_url("not://the-right-format"), None);
+    }
 }

zed/src/test.rs 🔗

@@ -3,24 +3,27 @@ use crate::{
     channel::ChannelList,
     fs::RealFs,
     language::LanguageRegistry,
-    rpc,
+    rpc::{self, Client},
     settings::{self, ThemeRegistry},
     time::ReplicaId,
     user::UserStore,
     AppState,
 };
-use gpui::{Entity, ModelHandle, MutableAppContext};
+use anyhow::{anyhow, Result};
+use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
 use parking_lot::Mutex;
+use postage::{mpsc, prelude::Stream as _};
 use smol::channel;
 use std::{
     marker::PhantomData,
     path::{Path, PathBuf},
-    sync::Arc,
+    sync::{
+        atomic::{AtomicBool, Ordering::SeqCst},
+        Arc,
+    },
 };
 use tempdir::TempDir;
-
-#[cfg(feature = "test-support")]
-pub use zrpc::test::Channel;
+use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
 
 #[cfg(test)]
 #[ctor::ctor]
@@ -195,3 +198,117 @@ impl<T: Entity> Observer<T> {
         (observer, notify_rx)
     }
 }
+
+pub struct FakeServer {
+    peer: Arc<Peer>,
+    incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
+    connection_id: Mutex<Option<ConnectionId>>,
+    forbid_connections: AtomicBool,
+}
+
+impl FakeServer {
+    pub async fn for_client(
+        client_user_id: u64,
+        client: &mut Arc<Client>,
+        cx: &TestAppContext,
+    ) -> Arc<Self> {
+        let result = Arc::new(Self {
+            peer: Peer::new(),
+            incoming: Default::default(),
+            connection_id: Default::default(),
+            forbid_connections: Default::default(),
+        });
+
+        Arc::get_mut(client)
+            .unwrap()
+            .set_login_and_connect_callbacks(
+                move |cx| {
+                    cx.spawn(|_| async move {
+                        let access_token = "the-token".to_string();
+                        Ok((client_user_id, access_token))
+                    })
+                },
+                {
+                    let server = result.clone();
+                    move |user_id, access_token, cx| {
+                        assert_eq!(user_id, client_user_id);
+                        assert_eq!(access_token, "the-token");
+                        cx.spawn({
+                            let server = server.clone();
+                            move |cx| async move { server.connect(&cx).await }
+                        })
+                    }
+                },
+            );
+
+        client
+            .authenticate_and_connect(&cx.to_async())
+            .await
+            .unwrap();
+        result
+    }
+
+    pub async fn disconnect(&self) {
+        self.peer.disconnect(self.connection_id()).await;
+        self.connection_id.lock().take();
+        self.incoming.lock().take();
+    }
+
+    async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
+        if self.forbid_connections.load(SeqCst) {
+            Err(anyhow!("server is forbidding connections"))
+        } else {
+            let (client_conn, server_conn, _) = Conn::in_memory();
+            let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+            cx.background().spawn(io).detach();
+            *self.incoming.lock() = Some(incoming);
+            *self.connection_id.lock() = Some(connection_id);
+            Ok(client_conn)
+        }
+    }
+
+    pub fn forbid_connections(&self) {
+        self.forbid_connections.store(true, SeqCst);
+    }
+
+    pub fn allow_connections(&self) {
+        self.forbid_connections.store(false, SeqCst);
+    }
+
+    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+        self.peer.send(self.connection_id(), message).await.unwrap();
+    }
+
+    pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+        let message = self
+            .incoming
+            .lock()
+            .as_mut()
+            .expect("not connected")
+            .recv()
+            .await
+            .ok_or_else(|| anyhow!("other half hung up"))?;
+        let type_name = message.payload_type_name();
+        Ok(*message
+            .into_any()
+            .downcast::<TypedEnvelope<M>>()
+            .unwrap_or_else(|_| {
+                panic!(
+                    "fake server received unexpected message type: {:?}",
+                    type_name
+                );
+            }))
+    }
+
+    pub async fn respond<T: proto::RequestMessage>(
+        &self,
+        receipt: Receipt<T>,
+        response: T::Response,
+    ) {
+        self.peer.respond(receipt, response).await.unwrap()
+    }
+
+    fn connection_id(&self) -> ConnectionId {
+        self.connection_id.lock().expect("not connected")
+    }
+}

zed/src/worktree.rs 🔗

@@ -234,6 +234,7 @@ impl Worktree {
                         .into_iter()
                         .map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
                         .collect(),
+                    queued_operations: Default::default(),
                     languages,
                     _subscriptions,
                 })
@@ -656,6 +657,7 @@ pub struct LocalWorktree {
     shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
     peers: HashMap<PeerId, ReplicaId>,
     languages: Arc<LanguageRegistry>,
+    queued_operations: Vec<(u64, Operation)>,
     fs: Arc<dyn Fs>,
 }
 
@@ -711,6 +713,7 @@ impl LocalWorktree {
                 poll_task: None,
                 open_buffers: Default::default(),
                 shared_buffers: Default::default(),
+                queued_operations: Default::default(),
                 peers: Default::default(),
                 languages,
                 fs,
@@ -1091,6 +1094,7 @@ pub struct RemoteWorktree {
     open_buffers: HashMap<usize, RemoteBuffer>,
     peers: HashMap<PeerId, ReplicaId>,
     languages: Arc<LanguageRegistry>,
+    queued_operations: Vec<(u64, Operation)>,
     _subscriptions: Vec<rpc::Subscription>,
 }
 
@@ -1550,16 +1554,23 @@ impl File {
                     .map(|share| (share.rpc.clone(), share.remote_id)),
                 Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
             } {
-                cx.spawn(|_, _| async move {
+                cx.spawn(|worktree, mut cx| async move {
                     if let Err(error) = rpc
-                        .send(proto::UpdateBuffer {
+                        .request(proto::UpdateBuffer {
                             worktree_id: remote_id,
                             buffer_id,
-                            operations: Some(operation).iter().map(Into::into).collect(),
+                            operations: vec![(&operation).into()],
                         })
                         .await
                     {
-                        log::error!("error sending buffer operation: {}", error);
+                        worktree.update(&mut cx, |worktree, _| {
+                            log::error!("error sending buffer operation: {}", error);
+                            match worktree {
+                                Worktree::Local(t) => &mut t.queued_operations,
+                                Worktree::Remote(t) => &mut t.queued_operations,
+                            }
+                            .push((buffer_id, operation));
+                        });
                     }
                 })
                 .detach();
@@ -1582,7 +1593,7 @@ impl File {
                             .await
                         {
                             log::error!("error closing remote buffer: {}", error);
-                        };
+                        }
                     })
                     .detach();
             }

zrpc/proto/zed.proto 🔗

@@ -6,9 +6,9 @@ message Envelope {
     optional uint32 responding_to = 2;
     optional uint32 original_sender_id = 3;
     oneof payload {
-        Error error = 4;
-        Ping ping = 5;
-        Pong pong = 6;
+        Ack ack = 4;
+        Error error = 5;
+        Ping ping = 6;
         ShareWorktree share_worktree = 7;
         ShareWorktreeResponse share_worktree_response = 8;
         OpenWorktree open_worktree = 9;
@@ -40,13 +40,9 @@ message Envelope {
 
 // Messages
 
-message Ping {
-    int32 id = 1;
-}
+message Ping {}
 
-message Pong {
-    int32 id = 2;
-}
+message Ack {}
 
 message Error {
     string message = 1;

zrpc/src/conn.rs 🔗

@@ -0,0 +1,101 @@
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
+use std::{io, task::Poll};
+
+pub struct Conn {
+    pub(crate) tx:
+        Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+    pub(crate) rx: Box<
+        dyn 'static
+            + Send
+            + Unpin
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+    >,
+}
+
+impl Conn {
+    pub fn new<S>(stream: S) -> Self
+    where
+        S: 'static
+            + Send
+            + Unpin
+            + futures::Sink<WebSocketMessage, Error = WebSocketError>
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+    {
+        let (tx, rx) = stream.split();
+        Self {
+            tx: Box::new(tx),
+            rx: Box::new(rx),
+        }
+    }
+
+    pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
+        self.tx.send(message).await
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
+        let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
+        postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
+
+        let (a_tx, a_rx) = Self::channel(kill_rx.clone());
+        let (b_tx, b_rx) = Self::channel(kill_rx);
+        (
+            Self { tx: a_tx, rx: b_rx },
+            Self { tx: b_tx, rx: a_rx },
+            kill_tx,
+        )
+    }
+
+    #[cfg(any(test, feature = "test-support"))]
+    fn channel(
+        kill_rx: postage::watch::Receiver<Option<()>>,
+    ) -> (
+        Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+        Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
+    ) {
+        use futures::{future, SinkExt as _};
+        use io::{Error, ErrorKind};
+
+        let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
+        let tx = tx
+            .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+            .with({
+                let kill_rx = kill_rx.clone();
+                move |msg| {
+                    if kill_rx.borrow().is_none() {
+                        future::ready(Ok(msg))
+                    } else {
+                        future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
+                    }
+                }
+            });
+        let rx = KillableReceiver { kill_rx, rx };
+
+        (Box::new(tx), Box::new(rx))
+    }
+}
+
+struct KillableReceiver {
+    rx: mpsc::UnboundedReceiver<WebSocketMessage>,
+    kill_rx: postage::watch::Receiver<Option<()>>,
+}
+
+impl Stream for KillableReceiver {
+    type Item = Result<WebSocketMessage, WebSocketError>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
+            Poll::Ready(Some(Err(io::Error::new(
+                io::ErrorKind::Other,
+                "connection killed",
+            )
+            .into())))
+        } else {
+            self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+        }
+    }
+}

zrpc/src/lib.rs 🔗

@@ -1,7 +1,6 @@
 pub mod auth;
+mod conn;
 mod peer;
 pub mod proto;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
-
+pub use conn::Conn;
 pub use peer::*;

zrpc/src/peer.rs 🔗

@@ -1,8 +1,8 @@
-use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::Conn;
 use anyhow::{anyhow, Context, Result};
 use async_lock::{Mutex, RwLock};
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{FutureExt, StreamExt};
+use futures::FutureExt as _;
 use postage::{
     mpsc,
     prelude::{Sink as _, Stream as _},
@@ -98,21 +98,14 @@ impl Peer {
         })
     }
 
-    pub async fn add_connection<Conn>(
+    pub async fn add_connection(
         self: &Arc<Self>,
         conn: Conn,
     ) -> (
         ConnectionId,
         impl Future<Output = anyhow::Result<()>> + Send,
         mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
-    )
-    where
-        Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
-            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
-            + Send
-            + Unpin,
-    {
-        let (tx, rx) = conn.split();
+    ) {
         let connection_id = ConnectionId(
             self.next_connection_id
                 .fetch_add(1, atomic::Ordering::SeqCst),
@@ -124,9 +117,10 @@ impl Peer {
             next_message_id: Default::default(),
             response_channels: Default::default(),
         };
-        let mut writer = MessageStream::new(tx);
-        let mut reader = MessageStream::new(rx);
+        let mut writer = MessageStream::new(conn.tx);
+        let mut reader = MessageStream::new(conn.rx);
 
+        let this = self.clone();
         let response_channels = connection.response_channels.clone();
         let handle_io = async move {
             loop {
@@ -147,6 +141,7 @@ impl Peer {
                                     if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                                         if incoming_tx.send(envelope).await.is_err() {
                                             response_channels.lock().await.clear();
+                                            this.connections.write().await.remove(&connection_id);
                                             return Ok(())
                                         }
                                     } else {
@@ -158,6 +153,7 @@ impl Peer {
                             }
                             Err(error) => {
                                 response_channels.lock().await.clear();
+                                this.connections.write().await.remove(&connection_id);
                                 Err(error).context("received invalid RPC message")?;
                             }
                         },
@@ -165,11 +161,13 @@ impl Peer {
                             Some(outgoing) => {
                                 if let Err(result) = writer.write_message(&outgoing).await {
                                     response_channels.lock().await.clear();
+                                    this.connections.write().await.remove(&connection_id);
                                     Err(result).context("failed to write RPC message")?;
                                 }
                             }
                             None => {
                                 response_channels.lock().await.clear();
+                                this.connections.write().await.remove(&connection_id);
                                 return Ok(())
                             }
                         }
@@ -342,7 +340,9 @@ impl Peer {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::{test, TypedEnvelope};
+    use crate::TypedEnvelope;
+    use async_tungstenite::tungstenite::Message as WebSocketMessage;
+    use futures::StreamExt as _;
 
     #[test]
     fn test_request_response() {
@@ -352,12 +352,12 @@ mod tests {
             let client1 = Peer::new();
             let client2 = Peer::new();
 
-            let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
+            let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
             let (client1_conn_id, io_task1, _) =
                 client1.add_connection(client1_to_server_conn).await;
             let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
 
-            let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
+            let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
             let (client2_conn_id, io_task3, _) =
                 client2.add_connection(client2_to_server_conn).await;
             let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -371,18 +371,18 @@ mod tests {
 
             assert_eq!(
                 client1
-                    .request(client1_conn_id, proto::Ping { id: 1 },)
+                    .request(client1_conn_id, proto::Ping {},)
                     .await
                     .unwrap(),
-                proto::Pong { id: 1 }
+                proto::Ack {}
             );
 
             assert_eq!(
                 client2
-                    .request(client2_conn_id, proto::Ping { id: 2 },)
+                    .request(client2_conn_id, proto::Ping {},)
                     .await
                     .unwrap(),
-                proto::Pong { id: 2 }
+                proto::Ack {}
             );
 
             assert_eq!(
@@ -438,13 +438,7 @@ mod tests {
                     let envelope = envelope.into_any();
                     if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
                         let receipt = envelope.receipt();
-                        peer.respond(
-                            receipt,
-                            proto::Pong {
-                                id: envelope.payload.id,
-                            },
-                        )
-                        .await?
+                        peer.respond(receipt, proto::Ack {}).await?
                     } else if let Some(envelope) =
                         envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
                     {
@@ -492,7 +486,7 @@ mod tests {
     #[test]
     fn test_disconnect() {
         smol::block_on(async move {
-            let (client_conn, mut server_conn) = test::Channel::bidirectional();
+            let (client_conn, mut server_conn, _) = Conn::in_memory();
 
             let client = Peer::new();
             let (connection_id, io_handler, mut incoming) =
@@ -516,18 +510,17 @@ mod tests {
 
             io_ended_rx.recv().await;
             messages_ended_rx.recv().await;
-            assert!(
-                futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
-                    .await
-                    .is_err()
-            );
+            assert!(server_conn
+                .send(WebSocketMessage::Binary(vec![]))
+                .await
+                .is_err());
         });
     }
 
     #[test]
     fn test_io_error() {
         smol::block_on(async move {
-            let (client_conn, server_conn) = test::Channel::bidirectional();
+            let (client_conn, server_conn, _) = Conn::in_memory();
             drop(server_conn);
 
             let client = Peer::new();
@@ -537,7 +530,7 @@ mod tests {
             smol::spawn(async move { incoming.next().await }).detach();
 
             let err = client
-                .request(connection_id, proto::Ping { id: 42 })
+                .request(connection_id, proto::Ping {})
                 .await
                 .unwrap_err();
             assert_eq!(err.to_string(), "connection was closed");

zrpc/src/proto.rs 🔗

@@ -120,6 +120,7 @@ macro_rules! entity_messages {
 }
 
 messages!(
+    Ack,
     AddPeer,
     BufferSaved,
     ChannelMessageSent,
@@ -140,7 +141,6 @@ messages!(
     OpenWorktree,
     OpenWorktreeResponse,
     Ping,
-    Pong,
     RemovePeer,
     SaveBuffer,
     SendChannelMessage,
@@ -157,8 +157,9 @@ request_messages!(
     (JoinChannel, JoinChannelResponse),
     (OpenBuffer, OpenBufferResponse),
     (OpenWorktree, OpenWorktreeResponse),
-    (Ping, Pong),
+    (Ping, Ack),
     (SaveBuffer, BufferSaved),
+    (UpdateBuffer, Ack),
     (ShareWorktree, ShareWorktreeResponse),
     (SendChannelMessage, SendChannelMessageResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
@@ -247,30 +248,3 @@ impl From<SystemTime> for Timestamp {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::test;
-
-    #[test]
-    fn test_round_trip_message() {
-        smol::block_on(async {
-            let stream = test::Channel::new();
-            let message1 = Ping { id: 5 }.into_envelope(3, None, None);
-            let message2 = OpenBuffer {
-                worktree_id: 0,
-                path: "some/path".to_string(),
-            }
-            .into_envelope(5, None, None);
-
-            let mut message_stream = MessageStream::new(stream);
-            message_stream.write_message(&message1).await.unwrap();
-            message_stream.write_message(&message2).await.unwrap();
-            let decoded_message1 = message_stream.read_message().await.unwrap();
-            let decoded_message2 = message_stream.read_message().await.unwrap();
-            assert_eq!(decoded_message1, message1);
-            assert_eq!(decoded_message2, message2);
-        });
-    }
-}

zrpc/src/test.rs 🔗

@@ -1,64 +0,0 @@
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use std::{
-    io,
-    pin::Pin,
-    task::{Context, Poll},
-};
-
-pub struct Channel {
-    tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
-    rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
-}
-
-impl Channel {
-    pub fn new() -> Self {
-        let (tx, rx) = futures::channel::mpsc::unbounded();
-        Self { tx, rx }
-    }
-
-    pub fn bidirectional() -> (Self, Self) {
-        let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
-        let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
-        let a = Self { tx: a_tx, rx: b_rx };
-        let b = Self { tx: b_tx, rx: a_rx };
-        (a, b)
-    }
-}
-
-impl futures::Sink<WebSocketMessage> for Channel {
-    type Error = WebSocketError;
-
-    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Pin::new(&mut self.tx)
-            .poll_ready(cx)
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
-    }
-
-    fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
-        Pin::new(&mut self.tx)
-            .start_send(item)
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
-    }
-
-    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Pin::new(&mut self.tx)
-            .poll_flush(cx)
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
-    }
-
-    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        Pin::new(&mut self.tx)
-            .poll_close(cx)
-            .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
-    }
-}
-
-impl futures::Stream for Channel {
-    type Item = Result<WebSocketMessage, WebSocketError>;
-
-    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        Pin::new(&mut self.rx)
-            .poll_next(cx)
-            .map(|i| i.map(|i| Ok(i)))
-    }
-}