WIP

Max Brunsfeld created

Change summary

server/src/rpc.rs       | 712 +++++++++---------------------------------
server/src/rpc/store.rs | 574 ++++++++++++++++++++++++++++++++++
zed/src/worktree.rs     |  35 +
zrpc/proto/zed.proto    |   5 
zrpc/src/proto.rs       |   1 
5 files changed, 753 insertions(+), 574 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -1,3 +1,5 @@
+mod store;
+
 use super::{
     auth,
     db::{ChannelId, MessageId, UserId},
@@ -8,16 +10,17 @@ use anyhow::anyhow;
 use async_std::{sync::RwLock, task};
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use futures::{future::BoxFuture, FutureExt};
-use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
+use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _};
 use sha1::{Digest as _, Sha1};
 use std::{
     any::TypeId,
-    collections::{hash_map, HashMap, HashSet},
+    collections::{HashMap, HashSet},
     future::Future,
     mem,
     sync::Arc,
     time::Instant,
 };
+use store::{ReplicaId, Store, Worktree};
 use surf::StatusCode;
 use tide::log;
 use tide::{
@@ -30,8 +33,6 @@ use zrpc::{
     Connection, ConnectionId, Peer, TypedEnvelope,
 };
 
-type ReplicaId = u16;
-
 type MessageHandler = Box<
     dyn Send
         + Sync
@@ -40,46 +41,12 @@ type MessageHandler = Box<
 
 pub struct Server {
     peer: Arc<Peer>,
-    state: RwLock<ServerState>,
+    store: RwLock<Store>,
     app_state: Arc<AppState>,
     handlers: HashMap<TypeId, MessageHandler>,
     notifications: Option<mpsc::Sender<()>>,
 }
 
-#[derive(Default)]
-struct ServerState {
-    connections: HashMap<ConnectionId, ConnectionState>,
-    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
-    worktrees: HashMap<u64, Worktree>,
-    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
-    channels: HashMap<ChannelId, Channel>,
-    next_worktree_id: u64,
-}
-
-struct ConnectionState {
-    user_id: UserId,
-    worktrees: HashSet<u64>,
-    channels: HashSet<ChannelId>,
-}
-
-struct Worktree {
-    host_connection_id: ConnectionId,
-    collaborator_user_ids: Vec<UserId>,
-    root_name: String,
-    share: Option<WorktreeShare>,
-}
-
-struct WorktreeShare {
-    guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
-    active_replica_ids: HashSet<ReplicaId>,
-    entries: HashMap<u64, proto::Entry>,
-}
-
-#[derive(Default)]
-struct Channel {
-    connection_ids: HashSet<ConnectionId>,
-}
-
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
 const MAX_MESSAGE_LEN: usize = 1024;
 
@@ -92,7 +59,7 @@ impl Server {
         let mut server = Self {
             peer,
             app_state,
-            state: Default::default(),
+            store: Default::default(),
             handlers: Default::default(),
             notifications,
         };
@@ -100,7 +67,7 @@ impl Server {
         server
             .add_handler(Server::ping)
             .add_handler(Server::open_worktree)
-            .add_handler(Server::handle_close_worktree)
+            .add_handler(Server::close_worktree)
             .add_handler(Server::share_worktree)
             .add_handler(Server::unshare_worktree)
             .add_handler(Server::join_worktree)
@@ -149,7 +116,10 @@ impl Server {
         async move {
             let (connection_id, handle_io, mut incoming_rx) =
                 this.peer.add_connection(connection).await;
-            this.add_connection(connection_id, user_id).await;
+            this.store
+                .write()
+                .await
+                .add_connection(connection_id, user_id);
             if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
                 log::error!("error updating collaborators for {:?}: {}", user_id, err);
             }
@@ -197,61 +167,40 @@ impl Server {
         }
     }
 
-    async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
+    async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
         self.peer.disconnect(connection_id).await;
-        self.remove_connection(connection_id).await?;
-        Ok(())
-    }
+        let removed_connection = self.store.write().await.remove_connection(connection_id)?;
 
-    // Add a new connection associated with a given user.
-    async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
-        let mut state = self.state.write().await;
-        state.connections.insert(
-            connection_id,
-            ConnectionState {
-                user_id,
-                worktrees: Default::default(),
-                channels: Default::default(),
-            },
-        );
-        state
-            .connections_by_user_id
-            .entry(user_id)
-            .or_default()
-            .insert(connection_id);
-    }
-
-    // Remove the given connection and its association with any worktrees.
-    async fn remove_connection(
-        self: &Arc<Server>,
-        connection_id: ConnectionId,
-    ) -> tide::Result<()> {
-        let mut worktree_ids = Vec::new();
-        let mut state = self.state.write().await;
-        if let Some(connection) = state.connections.remove(&connection_id) {
-            worktree_ids = connection.worktrees.into_iter().collect();
-
-            for channel_id in connection.channels {
-                if let Some(channel) = state.channels.get_mut(&channel_id) {
-                    channel.connection_ids.remove(&connection_id);
-                }
-            }
-
-            let user_connections = state
-                .connections_by_user_id
-                .get_mut(&connection.user_id)
-                .unwrap();
-            user_connections.remove(&connection_id);
-            if user_connections.is_empty() {
-                state.connections_by_user_id.remove(&connection.user_id);
+        for (worktree_id, worktree) in removed_connection.hosted_worktrees {
+            if let Some(share) = worktree.share {
+                broadcast(
+                    connection_id,
+                    share.guest_connection_ids.keys().copied().collect(),
+                    |conn_id| {
+                        self.peer
+                            .send(conn_id, proto::UnshareWorktree { worktree_id })
+                    },
+                )
+                .await?;
             }
         }
 
-        drop(state);
-        for worktree_id in worktree_ids {
-            self.close_worktree(worktree_id, connection_id).await?;
+        for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids {
+            broadcast(connection_id, peer_ids, |conn_id| {
+                self.peer.send(
+                    conn_id,
+                    proto::RemovePeer {
+                        worktree_id,
+                        peer_id: connection_id.0,
+                    },
+                )
+            })
+            .await?;
         }
 
+        self.update_collaborators_for_users(removed_connection.collaborator_ids.iter())
+            .await;
+
         Ok(())
     }
 
@@ -266,7 +215,7 @@ impl Server {
     ) -> tide::Result<()> {
         let receipt = request.receipt();
         let host_user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -289,7 +238,7 @@ impl Server {
         }
 
         let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
-        let worktree_id = self.state.write().await.add_worktree(Worktree {
+        let worktree_id = self.store.write().await.add_worktree(Worktree {
             host_connection_id: request.sender_id,
             collaborator_user_ids: collaborator_user_ids.clone(),
             root_name: request.payload.root_name,
@@ -305,6 +254,33 @@ impl Server {
         Ok(())
     }
 
+    async fn close_worktree(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::CloseWorktree>,
+    ) -> tide::Result<()> {
+        let worktree_id = request.payload.worktree_id;
+        let worktree = self
+            .store
+            .write()
+            .await
+            .remove_worktree(worktree_id, request.sender_id)?;
+
+        if let Some(share) = worktree.share {
+            broadcast(
+                request.sender_id,
+                share.guest_connection_ids.keys().copied().collect(),
+                |conn_id| {
+                    self.peer
+                        .send(conn_id, proto::UnshareWorktree { worktree_id })
+                },
+            )
+            .await?;
+        }
+        self.update_collaborators_for_users(&worktree.collaborator_user_ids)
+            .await?;
+        Ok(())
+    }
+
     async fn share_worktree(
         self: Arc<Server>,
         mut request: TypedEnvelope<proto::ShareWorktree>,
@@ -319,16 +295,12 @@ impl Server {
             .map(|entry| (entry.id, entry))
             .collect();
 
-        let mut state = self.state.write().await;
-        if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
-            worktree.share = Some(WorktreeShare {
-                guest_connection_ids: Default::default(),
-                active_replica_ids: Default::default(),
-                entries,
-            });
-            let collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
-            drop(state);
+        if let Some(collaborator_user_ids) =
+            self.store
+                .write()
+                .await
+                .share_worktree(worktree.id, request.sender_id, entries)
+        {
             self.peer
                 .respond(request.receipt(), proto::ShareWorktreeResponse {})
                 .await?;
@@ -352,26 +324,11 @@ impl Server {
         request: TypedEnvelope<proto::UnshareWorktree>,
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
-
-        let connection_ids;
-        let collaborator_user_ids;
-        {
-            let mut state = self.state.write().await;
-            let worktree = state.write_worktree(worktree_id, request.sender_id)?;
-            if worktree.host_connection_id != request.sender_id {
-                return Err(anyhow!("no such worktree"))?;
-            }
-
-            connection_ids = worktree.connection_ids();
-            collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
-            worktree.share.take();
-            for connection_id in &connection_ids {
-                if let Some(connection) = state.connections.get_mut(connection_id) {
-                    connection.worktrees.remove(&worktree_id);
-                }
-            }
-        }
+        let (connection_ids, collaborator_user_ids) = self
+            .store
+            .write()
+            .await
+            .unshare_worktree(worktree_id, request.sender_id)?;
 
         broadcast(request.sender_id, connection_ids, |conn_id| {
             self.peer
@@ -390,7 +347,7 @@ impl Server {
     ) -> tide::Result<()> {
         let worktree_id = request.payload.worktree_id;
         let user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -398,7 +355,7 @@ impl Server {
         let response;
         let connection_ids;
         let collaborator_user_ids;
-        let mut state = self.state.write().await;
+        let mut state = self.store.write().await;
         match state.join_worktree(request.sender_id, user_id, worktree_id) {
             Ok((peer_replica_id, worktree)) => {
                 let share = worktree.share()?;
@@ -462,48 +419,17 @@ impl Server {
         Ok(())
     }
 
-    async fn handle_close_worktree(
-        self: Arc<Server>,
-        request: TypedEnvelope<proto::CloseWorktree>,
-    ) -> tide::Result<()> {
-        self.close_worktree(request.payload.worktree_id, request.sender_id)
-            .await
-    }
-
-    async fn close_worktree(
+    async fn leave_worktree(
         self: &Arc<Server>,
         worktree_id: u64,
         sender_conn_id: ConnectionId,
     ) -> tide::Result<()> {
-        let connection_ids;
-        let collaborator_user_ids;
-        let mut is_host = false;
-        let mut is_guest = false;
+        if let Some((connection_ids, collaborator_ids)) = self
+            .store
+            .write()
+            .await
+            .leave_worktree(sender_conn_id, worktree_id)
         {
-            let mut state = self.state.write().await;
-            let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
-            connection_ids = worktree.connection_ids();
-            collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
-            if worktree.host_connection_id == sender_conn_id {
-                is_host = true;
-                state.remove_worktree(worktree_id);
-            } else {
-                let share = worktree.share_mut()?;
-                if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
-                    is_guest = true;
-                    share.active_replica_ids.remove(&replica_id);
-                }
-            }
-        }
-
-        if is_host {
-            broadcast(sender_conn_id, connection_ids, |conn_id| {
-                self.peer
-                    .send(conn_id, proto::UnshareWorktree { worktree_id })
-            })
-            .await?;
-        } else if is_guest {
             broadcast(sender_conn_id, connection_ids, |conn_id| {
                 self.peer.send(
                     conn_id,
@@ -513,10 +439,10 @@ impl Server {
                     },
                 )
             })
-            .await?
-        }
-        self.update_collaborators_for_users(&collaborator_user_ids)
             .await?;
+            self.update_collaborators_for_users(&collaborator_ids)
+                .await?;
+        }
         Ok(())
     }
 
@@ -524,22 +450,19 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateWorktree>,
     ) -> tide::Result<()> {
-        {
-            let mut state = self.state.write().await;
-            let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
-            let share = worktree.share_mut()?;
-
-            for entry_id in &request.payload.removed_entries {
-                share.entries.remove(&entry_id);
-            }
-
-            for entry in &request.payload.updated_entries {
-                share.entries.insert(entry.id, entry.clone());
-            }
-        }
+        let connection_ids = self.store.write().await.update_worktree(
+            request.sender_id,
+            request.payload.worktree_id,
+            &request.payload.removed_entries,
+            &request.payload.updated_entries,
+        )?;
+
+        broadcast(request.sender_id, connection_ids, |connection_id| {
+            self.peer
+                .forward_send(request.sender_id, connection_id, request.payload.clone())
+        })
+        .await?;
 
-        self.broadcast_in_worktree(request.payload.worktree_id, &request)
-            .await?;
         Ok(())
     }
 
@@ -548,14 +471,11 @@ impl Server {
         request: TypedEnvelope<proto::OpenBuffer>,
     ) -> tide::Result<()> {
         let receipt = request.receipt();
-        let worktree_id = request.payload.worktree_id;
         let host_connection_id = self
-            .state
+            .store
             .read()
             .await
-            .read_worktree(worktree_id, request.sender_id)?
-            .host_connection_id;
-
+            .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         let response = self
             .peer
             .forward_request(request.sender_id, host_connection_id, request.payload)
@@ -569,16 +489,13 @@ impl Server {
         request: TypedEnvelope<proto::CloseBuffer>,
     ) -> tide::Result<()> {
         let host_connection_id = self
-            .state
+            .store
             .read()
             .await
-            .read_worktree(request.payload.worktree_id, request.sender_id)?
-            .host_connection_id;
-
+            .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
         self.peer
             .forward_send(request.sender_id, host_connection_id, request.payload)
             .await?;
-
         Ok(())
     }
 
@@ -589,15 +506,11 @@ impl Server {
         let host;
         let guests;
         {
-            let state = self.state.read().await;
-            let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
-            host = worktree.host_connection_id;
-            guests = worktree
-                .share()?
-                .guest_connection_ids
-                .keys()
-                .copied()
-                .collect::<Vec<_>>();
+            let state = self.store.read().await;
+            host = state
+                .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
+            guests = state
+                .worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?;
         }
 
         let sender = request.sender_id;
@@ -627,8 +540,18 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::UpdateBuffer>,
     ) -> tide::Result<()> {
-        self.broadcast_in_worktree(request.payload.worktree_id, &request)
-            .await?;
+        broadcast(
+            request.sender_id,
+            self.store
+                .read()
+                .await
+                .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
+            |connection_id| {
+                self.peer
+                    .forward_send(request.sender_id, connection_id, request.payload.clone())
+            },
+        )
+        .await?;
         self.peer.respond(request.receipt(), proto::Ack {}).await?;
         Ok(())
     }
@@ -637,8 +560,19 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::BufferSaved>,
     ) -> tide::Result<()> {
-        self.broadcast_in_worktree(request.payload.worktree_id, &request)
-            .await
+        broadcast(
+            request.sender_id,
+            self.store
+                .read()
+                .await
+                .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
+            |connection_id| {
+                self.peer
+                    .forward_send(request.sender_id, connection_id, request.payload.clone())
+            },
+        )
+        .await?;
+        Ok(())
     }
 
     async fn get_channels(
@@ -646,7 +580,7 @@ impl Server {
         request: TypedEnvelope<proto::GetChannels>,
     ) -> tide::Result<()> {
         let user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -698,45 +632,10 @@ impl Server {
     ) -> tide::Result<()> {
         let mut send_futures = Vec::new();
 
-        let state = self.state.read().await;
+        let state = self.store.read().await;
         for user_id in user_ids {
-            let mut collaborators = HashMap::new();
-            for worktree_id in state
-                .visible_worktrees_by_user_id
-                .get(&user_id)
-                .unwrap_or(&HashSet::new())
-            {
-                let worktree = &state.worktrees[worktree_id];
-
-                let mut guests = HashSet::new();
-                if let Ok(share) = worktree.share() {
-                    for guest_connection_id in share.guest_connection_ids.keys() {
-                        let user_id = state
-                            .user_id_for_connection(*guest_connection_id)
-                            .context("stale worktree guest connection")?;
-                        guests.insert(user_id.to_proto());
-                    }
-                }
-
-                let host_user_id = state
-                    .user_id_for_connection(worktree.host_connection_id)
-                    .context("stale worktree host connection")?;
-                let host =
-                    collaborators
-                        .entry(host_user_id)
-                        .or_insert_with(|| proto::Collaborator {
-                            user_id: host_user_id.to_proto(),
-                            worktrees: Vec::new(),
-                        });
-                host.worktrees.push(proto::WorktreeMetadata {
-                    root_name: worktree.root_name.clone(),
-                    is_shared: worktree.share().is_ok(),
-                    participants: guests.into_iter().collect(),
-                });
-            }
-
-            let collaborators = collaborators.into_values().collect::<Vec<_>>();
-            for connection_id in state.user_connection_ids(*user_id) {
+            let collaborators = state.collaborators_for_user(*user_id);
+            for connection_id in state.connection_ids_for_user(*user_id) {
                 send_futures.push(self.peer.send(
                     connection_id,
                     proto::UpdateCollaborators {
@@ -757,7 +656,7 @@ impl Server {
         request: TypedEnvelope<proto::JoinChannel>,
     ) -> tide::Result<()> {
         let user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -771,7 +670,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state
+        self.store
             .write()
             .await
             .join_channel(request.sender_id, channel_id);
@@ -806,7 +705,7 @@ impl Server {
         request: TypedEnvelope<proto::LeaveChannel>,
     ) -> tide::Result<()> {
         let user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -820,7 +719,7 @@ impl Server {
             Err(anyhow!("access denied"))?;
         }
 
-        self.state
+        self.store
             .write()
             .await
             .leave_channel(request.sender_id, channel_id);
@@ -837,10 +736,10 @@ impl Server {
         let user_id;
         let connection_ids;
         {
-            let state = self.state.read().await;
+            let state = self.store.read().await;
             user_id = state.user_id_for_connection(request.sender_id)?;
-            if let Some(channel) = state.channels.get(&channel_id) {
-                connection_ids = channel.connection_ids();
+            if let Some(ids) = state.channel_connection_ids(channel_id) {
+                connection_ids = ids;
             } else {
                 return Ok(());
             }
@@ -925,7 +824,7 @@ impl Server {
         request: TypedEnvelope<proto::GetChannelMessages>,
     ) -> tide::Result<()> {
         let user_id = self
-            .state
+            .store
             .read()
             .await
             .user_id_for_connection(request.sender_id)?;
@@ -968,27 +867,6 @@ impl Server {
             .await?;
         Ok(())
     }
-
-    async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
-        &self,
-        worktree_id: u64,
-        message: &TypedEnvelope<T>,
-    ) -> tide::Result<()> {
-        let connection_ids = self
-            .state
-            .read()
-            .await
-            .read_worktree(worktree_id, message.sender_id)?
-            .connection_ids();
-
-        broadcast(message.sender_id, connection_ids, |conn_id| {
-            self.peer
-                .forward_send(message.sender_id, conn_id, message.payload.clone())
-        })
-        .await?;
-
-        Ok(())
-    }
 }
 
 pub async fn broadcast<F, T>(
@@ -1008,292 +886,6 @@ where
     Ok(())
 }
 
-impl ServerState {
-    fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
-        if let Some(connection) = self.connections.get_mut(&connection_id) {
-            connection.channels.insert(channel_id);
-            self.channels
-                .entry(channel_id)
-                .or_default()
-                .connection_ids
-                .insert(connection_id);
-        }
-    }
-
-    fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
-        if let Some(connection) = self.connections.get_mut(&connection_id) {
-            connection.channels.remove(&channel_id);
-            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
-                entry.get_mut().connection_ids.remove(&connection_id);
-                if entry.get_mut().connection_ids.is_empty() {
-                    entry.remove();
-                }
-            }
-        }
-    }
-
-    fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
-        Ok(self
-            .connections
-            .get(&connection_id)
-            .ok_or_else(|| anyhow!("unknown connection"))?
-            .user_id)
-    }
-
-    fn user_connection_ids<'a>(
-        &'a self,
-        user_id: UserId,
-    ) -> impl 'a + Iterator<Item = ConnectionId> {
-        self.connections_by_user_id
-            .get(&user_id)
-            .into_iter()
-            .flatten()
-            .copied()
-    }
-
-    // Add the given connection as a guest of the given worktree
-    fn join_worktree(
-        &mut self,
-        connection_id: ConnectionId,
-        user_id: UserId,
-        worktree_id: u64,
-    ) -> tide::Result<(ReplicaId, &Worktree)> {
-        let connection = self
-            .connections
-            .get_mut(&connection_id)
-            .ok_or_else(|| anyhow!("no such connection"))?;
-        let worktree = self
-            .worktrees
-            .get_mut(&worktree_id)
-            .ok_or_else(|| anyhow!("no such worktree"))?;
-        if !worktree.collaborator_user_ids.contains(&user_id) {
-            Err(anyhow!("no such worktree"))?;
-        }
-
-        let share = worktree.share_mut()?;
-        connection.worktrees.insert(worktree_id);
-
-        let mut replica_id = 1;
-        while share.active_replica_ids.contains(&replica_id) {
-            replica_id += 1;
-        }
-        share.active_replica_ids.insert(replica_id);
-        share.guest_connection_ids.insert(connection_id, replica_id);
-        return Ok((replica_id, worktree));
-    }
-
-    fn read_worktree(
-        &self,
-        worktree_id: u64,
-        connection_id: ConnectionId,
-    ) -> tide::Result<&Worktree> {
-        let worktree = self
-            .worktrees
-            .get(&worktree_id)
-            .ok_or_else(|| anyhow!("worktree not found"))?;
-
-        if worktree.host_connection_id == connection_id
-            || worktree
-                .share()?
-                .guest_connection_ids
-                .contains_key(&connection_id)
-        {
-            Ok(worktree)
-        } else {
-            Err(anyhow!(
-                "{} is not a member of worktree {}",
-                connection_id,
-                worktree_id
-            ))?
-        }
-    }
-
-    fn write_worktree(
-        &mut self,
-        worktree_id: u64,
-        connection_id: ConnectionId,
-    ) -> tide::Result<&mut Worktree> {
-        let worktree = self
-            .worktrees
-            .get_mut(&worktree_id)
-            .ok_or_else(|| anyhow!("worktree not found"))?;
-
-        if worktree.host_connection_id == connection_id
-            || worktree.share.as_ref().map_or(false, |share| {
-                share.guest_connection_ids.contains_key(&connection_id)
-            })
-        {
-            Ok(worktree)
-        } else {
-            Err(anyhow!(
-                "{} is not a member of worktree {}",
-                connection_id,
-                worktree_id
-            ))?
-        }
-    }
-
-    fn add_worktree(&mut self, worktree: Worktree) -> u64 {
-        let worktree_id = self.next_worktree_id;
-        for collaborator_user_id in &worktree.collaborator_user_ids {
-            self.visible_worktrees_by_user_id
-                .entry(*collaborator_user_id)
-                .or_default()
-                .insert(worktree_id);
-        }
-        self.next_worktree_id += 1;
-        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
-            connection.worktrees.insert(worktree_id);
-        }
-        self.worktrees.insert(worktree_id, worktree);
-
-        #[cfg(test)]
-        self.check_invariants();
-
-        worktree_id
-    }
-
-    fn remove_worktree(&mut self, worktree_id: u64) {
-        let worktree = self.worktrees.remove(&worktree_id).unwrap();
-        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
-            connection.worktrees.remove(&worktree_id);
-        }
-        if let Some(share) = worktree.share {
-            for connection_id in share.guest_connection_ids.keys() {
-                if let Some(connection) = self.connections.get_mut(connection_id) {
-                    connection.worktrees.remove(&worktree_id);
-                }
-            }
-        }
-        for collaborator_user_id in worktree.collaborator_user_ids {
-            if let Some(visible_worktrees) = self
-                .visible_worktrees_by_user_id
-                .get_mut(&collaborator_user_id)
-            {
-                visible_worktrees.remove(&worktree_id);
-            }
-        }
-
-        #[cfg(test)]
-        self.check_invariants();
-    }
-
-    #[cfg(test)]
-    fn check_invariants(&self) {
-        for (connection_id, connection) in &self.connections {
-            for worktree_id in &connection.worktrees {
-                let worktree = &self.worktrees.get(&worktree_id).unwrap();
-                if worktree.host_connection_id != *connection_id {
-                    assert!(worktree
-                        .share()
-                        .unwrap()
-                        .guest_connection_ids
-                        .contains_key(connection_id));
-                }
-            }
-            for channel_id in &connection.channels {
-                let channel = self.channels.get(channel_id).unwrap();
-                assert!(channel.connection_ids.contains(connection_id));
-            }
-            assert!(self
-                .connections_by_user_id
-                .get(&connection.user_id)
-                .unwrap()
-                .contains(connection_id));
-        }
-
-        for (user_id, connection_ids) in &self.connections_by_user_id {
-            for connection_id in connection_ids {
-                assert_eq!(
-                    self.connections.get(connection_id).unwrap().user_id,
-                    *user_id
-                );
-            }
-        }
-
-        for (worktree_id, worktree) in &self.worktrees {
-            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
-            assert!(host_connection.worktrees.contains(worktree_id));
-
-            for collaborator_id in &worktree.collaborator_user_ids {
-                let visible_worktree_ids = self
-                    .visible_worktrees_by_user_id
-                    .get(collaborator_id)
-                    .unwrap();
-                assert!(visible_worktree_ids.contains(worktree_id));
-            }
-
-            if let Some(share) = &worktree.share {
-                for guest_connection_id in share.guest_connection_ids.keys() {
-                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
-                    assert!(guest_connection.worktrees.contains(worktree_id));
-                }
-                assert_eq!(
-                    share.active_replica_ids.len(),
-                    share.guest_connection_ids.len(),
-                );
-                assert_eq!(
-                    share.active_replica_ids,
-                    share
-                        .guest_connection_ids
-                        .values()
-                        .copied()
-                        .collect::<HashSet<_>>(),
-                );
-            }
-        }
-
-        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
-            for worktree_id in visible_worktree_ids {
-                let worktree = self.worktrees.get(worktree_id).unwrap();
-                assert!(worktree.collaborator_user_ids.contains(user_id));
-            }
-        }
-
-        for (channel_id, channel) in &self.channels {
-            for connection_id in &channel.connection_ids {
-                let connection = self.connections.get(connection_id).unwrap();
-                assert!(connection.channels.contains(channel_id));
-            }
-        }
-    }
-}
-
-impl Worktree {
-    pub fn connection_ids(&self) -> Vec<ConnectionId> {
-        if let Some(share) = &self.share {
-            share
-                .guest_connection_ids
-                .keys()
-                .copied()
-                .chain(Some(self.host_connection_id))
-                .collect()
-        } else {
-            vec![self.host_connection_id]
-        }
-    }
-
-    fn share(&self) -> tide::Result<&WorktreeShare> {
-        Ok(self
-            .share
-            .as_ref()
-            .ok_or_else(|| anyhow!("worktree is not shared"))?)
-    }
-
-    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
-        Ok(self
-            .share
-            .as_mut()
-            .ok_or_else(|| anyhow!("worktree is not shared"))?)
-    }
-}
-
-impl Channel {
-    fn connection_ids(&self) -> Vec<ConnectionId> {
-        self.connection_ids.iter().copied().collect()
-    }
-}
-
 pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
     let server = Server::new(app.state().clone(), rpc.clone(), None);
     app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
@@ -2477,16 +2069,16 @@ mod tests {
             })
         }
 
-        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
-            self.server.state.read().await
+        async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
+            self.server.store.read().await
         }
 
         async fn condition<F>(&mut self, mut predicate: F)
         where
-            F: FnMut(&ServerState) -> bool,
+            F: FnMut(&Store) -> bool,
         {
             async_std::future::timeout(Duration::from_millis(500), async {
-                while !(predicate)(&*self.server.state.read().await) {
+                while !(predicate)(&*self.server.store.read().await) {
                     self.notifications.recv().await;
                 }
             })

server/src/rpc/store.rs 🔗

@@ -0,0 +1,574 @@
+use crate::db::{ChannelId, MessageId, UserId};
+use crate::errors::TideResultExt;
+use anyhow::anyhow;
+use std::collections::{hash_map, HashMap, HashSet};
+use zrpc::{proto, ConnectionId};
+
+#[derive(Default)]
+pub struct Store {
+    connections: HashMap<ConnectionId, ConnectionState>,
+    connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
+    worktrees: HashMap<u64, Worktree>,
+    visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
+    channels: HashMap<ChannelId, Channel>,
+    next_worktree_id: u64,
+}
+
+struct ConnectionState {
+    user_id: UserId,
+    worktrees: HashSet<u64>,
+    channels: HashSet<ChannelId>,
+}
+
+pub struct Worktree {
+    pub host_connection_id: ConnectionId,
+    pub collaborator_user_ids: Vec<UserId>,
+    pub root_name: String,
+    pub share: Option<WorktreeShare>,
+}
+
+struct WorktreeShare {
+    pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
+    pub active_replica_ids: HashSet<ReplicaId>,
+    pub entries: HashMap<u64, proto::Entry>,
+}
+
+#[derive(Default)]
+struct Channel {
+    connection_ids: HashSet<ConnectionId>,
+}
+
+pub type ReplicaId = u16;
+
+#[derive(Default)]
+pub struct RemovedConnectionState {
+    pub hosted_worktrees: HashMap<u64, Worktree>,
+    pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
+    pub collaborator_ids: HashSet<UserId>,
+}
+
+impl Store {
+    pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
+        self.connections.insert(
+            connection_id,
+            ConnectionState {
+                user_id,
+                worktrees: Default::default(),
+                channels: Default::default(),
+            },
+        );
+        self.connections_by_user_id
+            .entry(user_id)
+            .or_default()
+            .insert(connection_id);
+    }
+
+    pub fn remove_connection(
+        &mut self,
+        connection_id: ConnectionId,
+    ) -> tide::Result<RemovedConnectionState> {
+        let connection = if let Some(connection) = self.connections.get(&connection_id) {
+            connection
+        } else {
+            return Err(anyhow!("no such connection"))?;
+        };
+
+        for channel_id in connection.channels {
+            if let Some(channel) = self.channels.get_mut(&channel_id) {
+                channel.connection_ids.remove(&connection_id);
+            }
+        }
+
+        let user_connections = self
+            .connections_by_user_id
+            .get_mut(&connection.user_id)
+            .unwrap();
+        user_connections.remove(&connection_id);
+        if user_connections.is_empty() {
+            self.connections_by_user_id.remove(&connection.user_id);
+        }
+
+        let mut result = RemovedConnectionState::default();
+        for worktree_id in connection.worktrees {
+            if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
+                result.hosted_worktrees.insert(worktree_id, worktree);
+                result
+                    .collaborator_ids
+                    .extend(worktree.collaborator_user_ids.iter().copied());
+            } else {
+                if let Some(worktree) = self.worktrees.get(&worktree_id) {
+                    result
+                        .guest_worktree_ids
+                        .insert(worktree_id, worktree.connection_ids());
+                    result
+                        .collaborator_ids
+                        .extend(worktree.collaborator_user_ids.iter().copied());
+                }
+            }
+        }
+
+        Ok(result)
+    }
+
+    pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+        if let Some(connection) = self.connections.get_mut(&connection_id) {
+            connection.channels.insert(channel_id);
+            self.channels
+                .entry(channel_id)
+                .or_default()
+                .connection_ids
+                .insert(connection_id);
+        }
+    }
+
+    pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+        if let Some(connection) = self.connections.get_mut(&connection_id) {
+            connection.channels.remove(&channel_id);
+            if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
+                entry.get_mut().connection_ids.remove(&connection_id);
+                if entry.get_mut().connection_ids.is_empty() {
+                    entry.remove();
+                }
+            }
+        }
+    }
+
+    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
+        Ok(self
+            .connections
+            .get(&connection_id)
+            .ok_or_else(|| anyhow!("unknown connection"))?
+            .user_id)
+    }
+
+    pub fn connection_ids_for_user<'a>(
+        &'a self,
+        user_id: UserId,
+    ) -> impl 'a + Iterator<Item = ConnectionId> {
+        self.connections_by_user_id
+            .get(&user_id)
+            .into_iter()
+            .flatten()
+            .copied()
+    }
+
+    pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
+        let mut collaborators = HashMap::new();
+        for worktree_id in self
+            .visible_worktrees_by_user_id
+            .get(&user_id)
+            .unwrap_or(&HashSet::new())
+        {
+            let worktree = &self.worktrees[worktree_id];
+
+            let mut guests = HashSet::new();
+            if let Ok(share) = worktree.share() {
+                for guest_connection_id in share.guest_connection_ids.keys() {
+                    if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
+                        guests.insert(user_id.to_proto());
+                    }
+                }
+            }
+
+            if let Ok(host_user_id) = self
+                .user_id_for_connection(worktree.host_connection_id)
+                .context("stale worktree host connection")
+            {
+                let host =
+                    collaborators
+                        .entry(host_user_id)
+                        .or_insert_with(|| proto::Collaborator {
+                            user_id: host_user_id.to_proto(),
+                            worktrees: Vec::new(),
+                        });
+                host.worktrees.push(proto::WorktreeMetadata {
+                    root_name: worktree.root_name.clone(),
+                    is_shared: worktree.share().is_ok(),
+                    participants: guests.into_iter().collect(),
+                });
+            }
+        }
+
+        collaborators.into_values().collect()
+    }
+
+    pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
+        let worktree_id = self.next_worktree_id;
+        for collaborator_user_id in &worktree.collaborator_user_ids {
+            self.visible_worktrees_by_user_id
+                .entry(*collaborator_user_id)
+                .or_default()
+                .insert(worktree_id);
+        }
+        self.next_worktree_id += 1;
+        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+            connection.worktrees.insert(worktree_id);
+        }
+        self.worktrees.insert(worktree_id, worktree);
+
+        #[cfg(test)]
+        self.check_invariants();
+
+        worktree_id
+    }
+
+    pub fn remove_worktree(
+        &mut self,
+        worktree_id: u64,
+        acting_connection_id: ConnectionId,
+    ) -> tide::Result<Worktree> {
+        let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
+            if e.get().host_connection_id != acting_connection_id {
+                Err(anyhow!("not your worktree"))?;
+            }
+            e.remove()
+        } else {
+            return Err(anyhow!("no such worktree"))?;
+        };
+
+        if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+            connection.worktrees.remove(&worktree_id);
+        }
+
+        if let Some(share) = worktree.share {
+            for connection_id in share.guest_connection_ids.keys() {
+                if let Some(connection) = self.connections.get_mut(connection_id) {
+                    connection.worktrees.remove(&worktree_id);
+                }
+            }
+        }
+
+        for collaborator_user_id in worktree.collaborator_user_ids {
+            if let Some(visible_worktrees) = self
+                .visible_worktrees_by_user_id
+                .get_mut(&collaborator_user_id)
+            {
+                visible_worktrees.remove(&worktree_id);
+            }
+        }
+
+        #[cfg(test)]
+        self.check_invariants();
+
+        Ok(worktree)
+    }
+
+    pub fn share_worktree(
+        &mut self,
+        worktree_id: u64,
+        connection_id: ConnectionId,
+        entries: HashMap<u64, proto::Entry>,
+    ) -> Option<Vec<UserId>> {
+        if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+            if worktree.host_connection_id == connection_id {
+                worktree.share = Some(WorktreeShare {
+                    guest_connection_ids: Default::default(),
+                    active_replica_ids: Default::default(),
+                    entries,
+                });
+                return Some(worktree.collaborator_user_ids.clone());
+            }
+        }
+        None
+    }
+
+    pub fn unshare_worktree(
+        &mut self,
+        worktree_id: u64,
+        acting_connection_id: ConnectionId,
+    ) -> tide::Result<(Vec<ConnectionId>, Vec<UserId>)> {
+        let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+            worktree
+        } else {
+            return Err(anyhow!("no such worktree"))?;
+        };
+
+        if worktree.host_connection_id != acting_connection_id {
+            return Err(anyhow!("not your worktree"))?;
+        }
+
+        let connection_ids = worktree.connection_ids();
+
+        if let Some(share) = worktree.share.take() {
+            for connection_id in &connection_ids {
+                if let Some(connection) = self.connections.get_mut(connection_id) {
+                    connection.worktrees.remove(&worktree_id);
+                }
+            }
+            Ok((connection_ids, worktree.collaborator_user_ids.clone()))
+        } else {
+            Err(anyhow!("worktree is not shared"))?
+        }
+    }
+
+    pub fn join_worktree(
+        &mut self,
+        connection_id: ConnectionId,
+        user_id: UserId,
+        worktree_id: u64,
+    ) -> tide::Result<(ReplicaId, &Worktree)> {
+        let connection = self
+            .connections
+            .get_mut(&connection_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        let worktree = self
+            .worktrees
+            .get_mut(&worktree_id)
+            .and_then(|worktree| {
+                if worktree.collaborator_user_ids.contains(&user_id) {
+                    Some(worktree)
+                } else {
+                    None
+                }
+            })
+            .ok_or_else(|| anyhow!("no such worktree"))?;
+
+        let share = worktree.share_mut()?;
+        connection.worktrees.insert(worktree_id);
+
+        let mut replica_id = 1;
+        while share.active_replica_ids.contains(&replica_id) {
+            replica_id += 1;
+        }
+        share.active_replica_ids.insert(replica_id);
+        share.guest_connection_ids.insert(connection_id, replica_id);
+        return Ok((replica_id, worktree));
+    }
+
+    pub fn leave_worktree(
+        &mut self,
+        connection_id: ConnectionId,
+        worktree_id: u64,
+    ) -> Option<(Vec<ConnectionId>, Vec<UserId>)> {
+        let worktree = self.worktrees.get_mut(&worktree_id)?;
+        let share = worktree.share.as_mut()?;
+        let replica_id = share.guest_connection_ids.remove(&connection_id)?;
+        share.active_replica_ids.remove(&replica_id);
+        Some((
+            worktree.connection_ids(),
+            worktree.collaborator_user_ids.clone(),
+        ))
+    }
+
+    pub fn update_worktree(
+        &mut self,
+        connection_id: ConnectionId,
+        worktree_id: u64,
+        removed_entries: &[u64],
+        updated_entries: &[proto::Entry],
+    ) -> tide::Result<Vec<ConnectionId>> {
+        let worktree = self.write_worktree(worktree_id, connection_id)?;
+        let share = worktree.share_mut()?;
+        for entry_id in removed_entries {
+            share.entries.remove(&entry_id);
+        }
+        for entry in updated_entries {
+            share.entries.insert(entry.id, entry.clone());
+        }
+        Ok(worktree.connection_ids())
+    }
+
+    pub fn worktree_host_connection_id(
+        &self,
+        connection_id: ConnectionId,
+        worktree_id: u64,
+    ) -> tide::Result<ConnectionId> {
+        Ok(self
+            .read_worktree(worktree_id, connection_id)?
+            .host_connection_id)
+    }
+
+    pub fn worktree_guest_connection_ids(
+        &self,
+        connection_id: ConnectionId,
+        worktree_id: u64,
+    ) -> tide::Result<Vec<ConnectionId>> {
+        Ok(self
+            .read_worktree(worktree_id, connection_id)?
+            .share()?
+            .guest_connection_ids
+            .keys()
+            .copied()
+            .collect())
+    }
+
+    pub fn worktree_connection_ids(
+        &self,
+        connection_id: ConnectionId,
+        worktree_id: u64,
+    ) -> tide::Result<Vec<ConnectionId>> {
+        Ok(self
+            .read_worktree(worktree_id, connection_id)?
+            .connection_ids())
+    }
+
+    pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
+        Some(self.channels.get(&channel_id)?.connection_ids())
+    }
+
+    fn read_worktree(
+        &self,
+        worktree_id: u64,
+        connection_id: ConnectionId,
+    ) -> tide::Result<&Worktree> {
+        let worktree = self
+            .worktrees
+            .get(&worktree_id)
+            .ok_or_else(|| anyhow!("worktree not found"))?;
+
+        if worktree.host_connection_id == connection_id
+            || worktree
+                .share()?
+                .guest_connection_ids
+                .contains_key(&connection_id)
+        {
+            Ok(worktree)
+        } else {
+            Err(anyhow!(
+                "{} is not a member of worktree {}",
+                connection_id,
+                worktree_id
+            ))?
+        }
+    }
+
+    fn write_worktree(
+        &mut self,
+        worktree_id: u64,
+        connection_id: ConnectionId,
+    ) -> tide::Result<&mut Worktree> {
+        let worktree = self
+            .worktrees
+            .get_mut(&worktree_id)
+            .ok_or_else(|| anyhow!("worktree not found"))?;
+
+        if worktree.host_connection_id == connection_id
+            || worktree.share.as_ref().map_or(false, |share| {
+                share.guest_connection_ids.contains_key(&connection_id)
+            })
+        {
+            Ok(worktree)
+        } else {
+            Err(anyhow!(
+                "{} is not a member of worktree {}",
+                connection_id,
+                worktree_id
+            ))?
+        }
+    }
+
+    #[cfg(test)]
+    fn check_invariants(&self) {
+        for (connection_id, connection) in &self.connections {
+            for worktree_id in &connection.worktrees {
+                let worktree = &self.worktrees.get(&worktree_id).unwrap();
+                if worktree.host_connection_id != *connection_id {
+                    assert!(worktree
+                        .share()
+                        .unwrap()
+                        .guest_connection_ids
+                        .contains_key(connection_id));
+                }
+            }
+            for channel_id in &connection.channels {
+                let channel = self.channels.get(channel_id).unwrap();
+                assert!(channel.connection_ids.contains(connection_id));
+            }
+            assert!(self
+                .connections_by_user_id
+                .get(&connection.user_id)
+                .unwrap()
+                .contains(connection_id));
+        }
+
+        for (user_id, connection_ids) in &self.connections_by_user_id {
+            for connection_id in connection_ids {
+                assert_eq!(
+                    self.connections.get(connection_id).unwrap().user_id,
+                    *user_id
+                );
+            }
+        }
+
+        for (worktree_id, worktree) in &self.worktrees {
+            let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
+            assert!(host_connection.worktrees.contains(worktree_id));
+
+            for collaborator_id in &worktree.collaborator_user_ids {
+                let visible_worktree_ids = self
+                    .visible_worktrees_by_user_id
+                    .get(collaborator_id)
+                    .unwrap();
+                assert!(visible_worktree_ids.contains(worktree_id));
+            }
+
+            if let Some(share) = &worktree.share {
+                for guest_connection_id in share.guest_connection_ids.keys() {
+                    let guest_connection = self.connections.get(guest_connection_id).unwrap();
+                    assert!(guest_connection.worktrees.contains(worktree_id));
+                }
+                assert_eq!(
+                    share.active_replica_ids.len(),
+                    share.guest_connection_ids.len(),
+                );
+                assert_eq!(
+                    share.active_replica_ids,
+                    share
+                        .guest_connection_ids
+                        .values()
+                        .copied()
+                        .collect::<HashSet<_>>(),
+                );
+            }
+        }
+
+        for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
+            for worktree_id in visible_worktree_ids {
+                let worktree = self.worktrees.get(worktree_id).unwrap();
+                assert!(worktree.collaborator_user_ids.contains(user_id));
+            }
+        }
+
+        for (channel_id, channel) in &self.channels {
+            for connection_id in &channel.connection_ids {
+                let connection = self.connections.get(connection_id).unwrap();
+                assert!(connection.channels.contains(channel_id));
+            }
+        }
+    }
+}
+
+impl Worktree {
+    pub fn connection_ids(&self) -> Vec<ConnectionId> {
+        if let Some(share) = &self.share {
+            share
+                .guest_connection_ids
+                .keys()
+                .copied()
+                .chain(Some(self.host_connection_id))
+                .collect()
+        } else {
+            vec![self.host_connection_id]
+        }
+    }
+
+    pub fn share(&self) -> tide::Result<&WorktreeShare> {
+        Ok(self
+            .share
+            .as_ref()
+            .ok_or_else(|| anyhow!("worktree is not shared"))?)
+    }
+
+    fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
+        Ok(self
+            .share
+            .as_mut()
+            .ok_or_else(|| anyhow!("worktree is not shared"))?)
+    }
+}
+
+impl Channel {
+    fn connection_ids(&self) -> Vec<ConnectionId> {
+        self.connection_ids.iter().copied().collect()
+    }
+}

zed/src/worktree.rs 🔗

@@ -66,21 +66,28 @@ impl Entity for Worktree {
     type Event = ();
 
     fn release(&mut self, cx: &mut MutableAppContext) {
-        let rpc = match self {
-            Self::Local(tree) => tree
-                .remote_id
-                .borrow()
-                .map(|remote_id| (tree.rpc.clone(), remote_id)),
-            Self::Remote(tree) => Some((tree.rpc.clone(), tree.remote_id)),
-        };
-
-        if let Some((rpc, worktree_id)) = rpc {
-            cx.spawn(|_| async move {
-                if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await {
-                    log::error!("error closing worktree {}: {}", worktree_id, err);
+        match self {
+            Self::Local(tree) => {
+                if let Some(worktree_id) = *tree.remote_id.borrow() {
+                    let rpc = tree.rpc.clone();
+                    cx.spawn(|_| async move {
+                        if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await {
+                            log::error!("error closing worktree: {}", err);
+                        }
+                    })
+                    .detach();
                 }
-            })
-            .detach();
+            }
+            Self::Remote(tree) => {
+                let rpc = tree.rpc.clone();
+                let worktree_id = tree.remote_id;
+                cx.spawn(|_| async move {
+                    if let Err(err) = rpc.send(proto::LeaveWorktree { worktree_id }).await {
+                        log::error!("error closing worktree: {}", err);
+                    }
+                })
+                .detach();
+            }
         }
     }
 }

zrpc/proto/zed.proto 🔗

@@ -39,6 +39,7 @@ message Envelope {
         OpenWorktreeResponse open_worktree_response = 34;
         UnshareWorktree unshare_worktree = 35;
         UpdateCollaborators update_collaborators = 36;
+        LeaveWorktree leave_worktree = 37;
     }
 }
 
@@ -75,6 +76,10 @@ message JoinWorktree {
     uint64 worktree_id = 1;
 }
 
+message LeaveWorktree {
+    uint64 worktree_id = 1;
+}
+
 message JoinWorktreeResponse {
     Worktree worktree = 2;
     uint32 replica_id = 3;

zrpc/src/proto.rs 🔗

@@ -139,6 +139,7 @@ messages!(
     JoinWorktree,
     JoinWorktreeResponse,
     LeaveChannel,
+    LeaveWorktree,
     OpenBuffer,
     OpenBufferResponse,
     OpenWorktree,