@@ -1,3 +1,5 @@
+mod store;
+
use super::{
auth,
db::{ChannelId, MessageId, UserId},
@@ -8,16 +10,17 @@ use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
-use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
+use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
use std::{
any::TypeId,
- collections::{hash_map, HashMap, HashSet},
+ collections::{HashMap, HashSet},
future::Future,
mem,
sync::Arc,
time::Instant,
};
+use store::{ReplicaId, Store, Worktree};
use surf::StatusCode;
use tide::log;
use tide::{
@@ -30,8 +33,6 @@ use zrpc::{
Connection, ConnectionId, Peer, TypedEnvelope,
};
-type ReplicaId = u16;
-
type MessageHandler = Box<
dyn Send
+ Sync
@@ -40,46 +41,12 @@ type MessageHandler = Box<
pub struct Server {
peer: Arc<Peer>,
- state: RwLock<ServerState>,
+ store: RwLock<Store>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
notifications: Option<mpsc::Sender<()>>,
}
-#[derive(Default)]
-struct ServerState {
- connections: HashMap<ConnectionId, ConnectionState>,
- connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
- worktrees: HashMap<u64, Worktree>,
- visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
- channels: HashMap<ChannelId, Channel>,
- next_worktree_id: u64,
-}
-
-struct ConnectionState {
- user_id: UserId,
- worktrees: HashSet<u64>,
- channels: HashSet<ChannelId>,
-}
-
-struct Worktree {
- host_connection_id: ConnectionId,
- collaborator_user_ids: Vec<UserId>,
- root_name: String,
- share: Option<WorktreeShare>,
-}
-
-struct WorktreeShare {
- guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
- active_replica_ids: HashSet<ReplicaId>,
- entries: HashMap<u64, proto::Entry>,
-}
-
-#[derive(Default)]
-struct Channel {
- connection_ids: HashSet<ConnectionId>,
-}
-
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
@@ -92,7 +59,7 @@ impl Server {
let mut server = Self {
peer,
app_state,
- state: Default::default(),
+ store: Default::default(),
handlers: Default::default(),
notifications,
};
@@ -100,7 +67,7 @@ impl Server {
server
.add_handler(Server::ping)
.add_handler(Server::open_worktree)
- .add_handler(Server::handle_close_worktree)
+ .add_handler(Server::close_worktree)
.add_handler(Server::share_worktree)
.add_handler(Server::unshare_worktree)
.add_handler(Server::join_worktree)
@@ -149,7 +116,10 @@ impl Server {
async move {
let (connection_id, handle_io, mut incoming_rx) =
this.peer.add_connection(connection).await;
- this.add_connection(connection_id, user_id).await;
+ this.store
+ .write()
+ .await
+ .add_connection(connection_id, user_id);
if let Err(err) = this.update_collaborators_for_users(&[user_id]).await {
log::error!("error updating collaborators for {:?}: {}", user_id, err);
}
@@ -197,61 +167,40 @@ impl Server {
}
}
- async fn sign_out(self: &Arc<Self>, connection_id: zrpc::ConnectionId) -> tide::Result<()> {
+ async fn sign_out(self: &Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id).await;
- self.remove_connection(connection_id).await?;
- Ok(())
- }
+ let removed_connection = self.store.write().await.remove_connection(connection_id)?;
- // Add a new connection associated with a given user.
- async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) {
- let mut state = self.state.write().await;
- state.connections.insert(
- connection_id,
- ConnectionState {
- user_id,
- worktrees: Default::default(),
- channels: Default::default(),
- },
- );
- state
- .connections_by_user_id
- .entry(user_id)
- .or_default()
- .insert(connection_id);
- }
-
- // Remove the given connection and its association with any worktrees.
- async fn remove_connection(
- self: &Arc<Server>,
- connection_id: ConnectionId,
- ) -> tide::Result<()> {
- let mut worktree_ids = Vec::new();
- let mut state = self.state.write().await;
- if let Some(connection) = state.connections.remove(&connection_id) {
- worktree_ids = connection.worktrees.into_iter().collect();
-
- for channel_id in connection.channels {
- if let Some(channel) = state.channels.get_mut(&channel_id) {
- channel.connection_ids.remove(&connection_id);
- }
- }
-
- let user_connections = state
- .connections_by_user_id
- .get_mut(&connection.user_id)
- .unwrap();
- user_connections.remove(&connection_id);
- if user_connections.is_empty() {
- state.connections_by_user_id.remove(&connection.user_id);
+ for (worktree_id, worktree) in removed_connection.hosted_worktrees {
+ if let Some(share) = worktree.share {
+ broadcast(
+ connection_id,
+ share.guest_connection_ids.keys().copied().collect(),
+ |conn_id| {
+ self.peer
+ .send(conn_id, proto::UnshareWorktree { worktree_id })
+ },
+ )
+ .await?;
}
}
- drop(state);
- for worktree_id in worktree_ids {
- self.close_worktree(worktree_id, connection_id).await?;
+ for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids {
+ broadcast(connection_id, peer_ids, |conn_id| {
+ self.peer.send(
+ conn_id,
+ proto::RemovePeer {
+ worktree_id,
+ peer_id: connection_id.0,
+ },
+ )
+ })
+ .await?;
}
+ self.update_collaborators_for_users(removed_connection.collaborator_ids.iter())
+ .await;
+
Ok(())
}
@@ -266,7 +215,7 @@ impl Server {
) -> tide::Result<()> {
let receipt = request.receipt();
let host_user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -289,7 +238,7 @@ impl Server {
}
let collaborator_user_ids = collaborator_user_ids.into_iter().collect::<Vec<_>>();
- let worktree_id = self.state.write().await.add_worktree(Worktree {
+ let worktree_id = self.store.write().await.add_worktree(Worktree {
host_connection_id: request.sender_id,
collaborator_user_ids: collaborator_user_ids.clone(),
root_name: request.payload.root_name,
@@ -305,6 +254,33 @@ impl Server {
Ok(())
}
+ async fn close_worktree(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::CloseWorktree>,
+ ) -> tide::Result<()> {
+ let worktree_id = request.payload.worktree_id;
+ let worktree = self
+ .store
+ .write()
+ .await
+ .remove_worktree(worktree_id, request.sender_id)?;
+
+ if let Some(share) = worktree.share {
+ broadcast(
+ request.sender_id,
+ share.guest_connection_ids.keys().copied().collect(),
+ |conn_id| {
+ self.peer
+ .send(conn_id, proto::UnshareWorktree { worktree_id })
+ },
+ )
+ .await?;
+ }
+ self.update_collaborators_for_users(&worktree.collaborator_user_ids)
+ .await?;
+ Ok(())
+ }
+
async fn share_worktree(
self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
@@ -319,16 +295,12 @@ impl Server {
.map(|entry| (entry.id, entry))
.collect();
- let mut state = self.state.write().await;
- if let Some(worktree) = state.worktrees.get_mut(&worktree.id) {
- worktree.share = Some(WorktreeShare {
- guest_connection_ids: Default::default(),
- active_replica_ids: Default::default(),
- entries,
- });
- let collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
- drop(state);
+ if let Some(collaborator_user_ids) =
+ self.store
+ .write()
+ .await
+ .share_worktree(worktree.id, request.sender_id, entries)
+ {
self.peer
.respond(request.receipt(), proto::ShareWorktreeResponse {})
.await?;
@@ -352,26 +324,11 @@ impl Server {
request: TypedEnvelope<proto::UnshareWorktree>,
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
-
- let connection_ids;
- let collaborator_user_ids;
- {
- let mut state = self.state.write().await;
- let worktree = state.write_worktree(worktree_id, request.sender_id)?;
- if worktree.host_connection_id != request.sender_id {
- return Err(anyhow!("no such worktree"))?;
- }
-
- connection_ids = worktree.connection_ids();
- collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
- worktree.share.take();
- for connection_id in &connection_ids {
- if let Some(connection) = state.connections.get_mut(connection_id) {
- connection.worktrees.remove(&worktree_id);
- }
- }
- }
+ let (connection_ids, collaborator_user_ids) = self
+ .store
+ .write()
+ .await
+ .unshare_worktree(worktree_id, request.sender_id)?;
broadcast(request.sender_id, connection_ids, |conn_id| {
self.peer
@@ -390,7 +347,7 @@ impl Server {
) -> tide::Result<()> {
let worktree_id = request.payload.worktree_id;
let user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -398,7 +355,7 @@ impl Server {
let response;
let connection_ids;
let collaborator_user_ids;
- let mut state = self.state.write().await;
+ let mut state = self.store.write().await;
match state.join_worktree(request.sender_id, user_id, worktree_id) {
Ok((peer_replica_id, worktree)) => {
let share = worktree.share()?;
@@ -462,48 +419,17 @@ impl Server {
Ok(())
}
- async fn handle_close_worktree(
- self: Arc<Server>,
- request: TypedEnvelope<proto::CloseWorktree>,
- ) -> tide::Result<()> {
- self.close_worktree(request.payload.worktree_id, request.sender_id)
- .await
- }
-
- async fn close_worktree(
+ async fn leave_worktree(
self: &Arc<Server>,
worktree_id: u64,
sender_conn_id: ConnectionId,
) -> tide::Result<()> {
- let connection_ids;
- let collaborator_user_ids;
- let mut is_host = false;
- let mut is_guest = false;
+ if let Some((connection_ids, collaborator_ids)) = self
+ .store
+ .write()
+ .await
+ .leave_worktree(sender_conn_id, worktree_id)
{
- let mut state = self.state.write().await;
- let worktree = state.write_worktree(worktree_id, sender_conn_id)?;
- connection_ids = worktree.connection_ids();
- collaborator_user_ids = worktree.collaborator_user_ids.clone();
-
- if worktree.host_connection_id == sender_conn_id {
- is_host = true;
- state.remove_worktree(worktree_id);
- } else {
- let share = worktree.share_mut()?;
- if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) {
- is_guest = true;
- share.active_replica_ids.remove(&replica_id);
- }
- }
- }
-
- if is_host {
- broadcast(sender_conn_id, connection_ids, |conn_id| {
- self.peer
- .send(conn_id, proto::UnshareWorktree { worktree_id })
- })
- .await?;
- } else if is_guest {
broadcast(sender_conn_id, connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -513,10 +439,10 @@ impl Server {
},
)
})
- .await?
- }
- self.update_collaborators_for_users(&collaborator_user_ids)
.await?;
+ self.update_collaborators_for_users(&collaborator_ids)
+ .await?;
+ }
Ok(())
}
@@ -524,22 +450,19 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<()> {
- {
- let mut state = self.state.write().await;
- let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?;
- let share = worktree.share_mut()?;
-
- for entry_id in &request.payload.removed_entries {
- share.entries.remove(&entry_id);
- }
-
- for entry in &request.payload.updated_entries {
- share.entries.insert(entry.id, entry.clone());
- }
- }
+ let connection_ids = self.store.write().await.update_worktree(
+ request.sender_id,
+ request.payload.worktree_id,
+ &request.payload.removed_entries,
+ &request.payload.updated_entries,
+ )?;
+
+ broadcast(request.sender_id, connection_ids, |connection_id| {
+ self.peer
+ .forward_send(request.sender_id, connection_id, request.payload.clone())
+ })
+ .await?;
- self.broadcast_in_worktree(request.payload.worktree_id, &request)
- .await?;
Ok(())
}
@@ -548,14 +471,11 @@ impl Server {
request: TypedEnvelope<proto::OpenBuffer>,
) -> tide::Result<()> {
let receipt = request.receipt();
- let worktree_id = request.payload.worktree_id;
let host_connection_id = self
- .state
+ .store
.read()
.await
- .read_worktree(worktree_id, request.sender_id)?
- .host_connection_id;
-
+ .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
let response = self
.peer
.forward_request(request.sender_id, host_connection_id, request.payload)
@@ -569,16 +489,13 @@ impl Server {
request: TypedEnvelope<proto::CloseBuffer>,
) -> tide::Result<()> {
let host_connection_id = self
- .state
+ .store
.read()
.await
- .read_worktree(request.payload.worktree_id, request.sender_id)?
- .host_connection_id;
-
+ .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
self.peer
.forward_send(request.sender_id, host_connection_id, request.payload)
.await?;
-
Ok(())
}
@@ -589,15 +506,11 @@ impl Server {
let host;
let guests;
{
- let state = self.state.read().await;
- let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?;
- host = worktree.host_connection_id;
- guests = worktree
- .share()?
- .guest_connection_ids
- .keys()
- .copied()
- .collect::<Vec<_>>();
+ let state = self.store.read().await;
+ host = state
+ .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?;
+ guests = state
+ .worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?;
}
let sender = request.sender_id;
@@ -627,8 +540,18 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::UpdateBuffer>,
) -> tide::Result<()> {
- self.broadcast_in_worktree(request.payload.worktree_id, &request)
- .await?;
+ broadcast(
+ request.sender_id,
+ self.store
+ .read()
+ .await
+ .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
+ |connection_id| {
+ self.peer
+ .forward_send(request.sender_id, connection_id, request.payload.clone())
+ },
+ )
+ .await?;
self.peer.respond(request.receipt(), proto::Ack {}).await?;
Ok(())
}
@@ -637,8 +560,19 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::BufferSaved>,
) -> tide::Result<()> {
- self.broadcast_in_worktree(request.payload.worktree_id, &request)
- .await
+ broadcast(
+ request.sender_id,
+ self.store
+ .read()
+ .await
+ .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?,
+ |connection_id| {
+ self.peer
+ .forward_send(request.sender_id, connection_id, request.payload.clone())
+ },
+ )
+ .await?;
+ Ok(())
}
async fn get_channels(
@@ -646,7 +580,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<()> {
let user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -698,45 +632,10 @@ impl Server {
) -> tide::Result<()> {
let mut send_futures = Vec::new();
- let state = self.state.read().await;
+ let state = self.store.read().await;
for user_id in user_ids {
- let mut collaborators = HashMap::new();
- for worktree_id in state
- .visible_worktrees_by_user_id
- .get(&user_id)
- .unwrap_or(&HashSet::new())
- {
- let worktree = &state.worktrees[worktree_id];
-
- let mut guests = HashSet::new();
- if let Ok(share) = worktree.share() {
- for guest_connection_id in share.guest_connection_ids.keys() {
- let user_id = state
- .user_id_for_connection(*guest_connection_id)
- .context("stale worktree guest connection")?;
- guests.insert(user_id.to_proto());
- }
- }
-
- let host_user_id = state
- .user_id_for_connection(worktree.host_connection_id)
- .context("stale worktree host connection")?;
- let host =
- collaborators
- .entry(host_user_id)
- .or_insert_with(|| proto::Collaborator {
- user_id: host_user_id.to_proto(),
- worktrees: Vec::new(),
- });
- host.worktrees.push(proto::WorktreeMetadata {
- root_name: worktree.root_name.clone(),
- is_shared: worktree.share().is_ok(),
- participants: guests.into_iter().collect(),
- });
- }
-
- let collaborators = collaborators.into_values().collect::<Vec<_>>();
- for connection_id in state.user_connection_ids(*user_id) {
+ let collaborators = state.collaborators_for_user(*user_id);
+ for connection_id in state.connection_ids_for_user(*user_id) {
send_futures.push(self.peer.send(
connection_id,
proto::UpdateCollaborators {
@@ -757,7 +656,7 @@ impl Server {
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<()> {
let user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -771,7 +670,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.state
+ self.store
.write()
.await
.join_channel(request.sender_id, channel_id);
@@ -806,7 +705,7 @@ impl Server {
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
let user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -820,7 +719,7 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.state
+ self.store
.write()
.await
.leave_channel(request.sender_id, channel_id);
@@ -837,10 +736,10 @@ impl Server {
let user_id;
let connection_ids;
{
- let state = self.state.read().await;
+ let state = self.store.read().await;
user_id = state.user_id_for_connection(request.sender_id)?;
- if let Some(channel) = state.channels.get(&channel_id) {
- connection_ids = channel.connection_ids();
+ if let Some(ids) = state.channel_connection_ids(channel_id) {
+ connection_ids = ids;
} else {
return Ok(());
}
@@ -925,7 +824,7 @@ impl Server {
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<()> {
let user_id = self
- .state
+ .store
.read()
.await
.user_id_for_connection(request.sender_id)?;
@@ -968,27 +867,6 @@ impl Server {
.await?;
Ok(())
}
-
- async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
- &self,
- worktree_id: u64,
- message: &TypedEnvelope<T>,
- ) -> tide::Result<()> {
- let connection_ids = self
- .state
- .read()
- .await
- .read_worktree(worktree_id, message.sender_id)?
- .connection_ids();
-
- broadcast(message.sender_id, connection_ids, |conn_id| {
- self.peer
- .forward_send(message.sender_id, conn_id, message.payload.clone())
- })
- .await?;
-
- Ok(())
- }
}
pub async fn broadcast<F, T>(
@@ -1008,292 +886,6 @@ where
Ok(())
}
-impl ServerState {
- fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
- if let Some(connection) = self.connections.get_mut(&connection_id) {
- connection.channels.insert(channel_id);
- self.channels
- .entry(channel_id)
- .or_default()
- .connection_ids
- .insert(connection_id);
- }
- }
-
- fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
- if let Some(connection) = self.connections.get_mut(&connection_id) {
- connection.channels.remove(&channel_id);
- if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
- entry.get_mut().connection_ids.remove(&connection_id);
- if entry.get_mut().connection_ids.is_empty() {
- entry.remove();
- }
- }
- }
- }
-
- fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
- Ok(self
- .connections
- .get(&connection_id)
- .ok_or_else(|| anyhow!("unknown connection"))?
- .user_id)
- }
-
- fn user_connection_ids<'a>(
- &'a self,
- user_id: UserId,
- ) -> impl 'a + Iterator<Item = ConnectionId> {
- self.connections_by_user_id
- .get(&user_id)
- .into_iter()
- .flatten()
- .copied()
- }
-
- // Add the given connection as a guest of the given worktree
- fn join_worktree(
- &mut self,
- connection_id: ConnectionId,
- user_id: UserId,
- worktree_id: u64,
- ) -> tide::Result<(ReplicaId, &Worktree)> {
- let connection = self
- .connections
- .get_mut(&connection_id)
- .ok_or_else(|| anyhow!("no such connection"))?;
- let worktree = self
- .worktrees
- .get_mut(&worktree_id)
- .ok_or_else(|| anyhow!("no such worktree"))?;
- if !worktree.collaborator_user_ids.contains(&user_id) {
- Err(anyhow!("no such worktree"))?;
- }
-
- let share = worktree.share_mut()?;
- connection.worktrees.insert(worktree_id);
-
- let mut replica_id = 1;
- while share.active_replica_ids.contains(&replica_id) {
- replica_id += 1;
- }
- share.active_replica_ids.insert(replica_id);
- share.guest_connection_ids.insert(connection_id, replica_id);
- return Ok((replica_id, worktree));
- }
-
- fn read_worktree(
- &self,
- worktree_id: u64,
- connection_id: ConnectionId,
- ) -> tide::Result<&Worktree> {
- let worktree = self
- .worktrees
- .get(&worktree_id)
- .ok_or_else(|| anyhow!("worktree not found"))?;
-
- if worktree.host_connection_id == connection_id
- || worktree
- .share()?
- .guest_connection_ids
- .contains_key(&connection_id)
- {
- Ok(worktree)
- } else {
- Err(anyhow!(
- "{} is not a member of worktree {}",
- connection_id,
- worktree_id
- ))?
- }
- }
-
- fn write_worktree(
- &mut self,
- worktree_id: u64,
- connection_id: ConnectionId,
- ) -> tide::Result<&mut Worktree> {
- let worktree = self
- .worktrees
- .get_mut(&worktree_id)
- .ok_or_else(|| anyhow!("worktree not found"))?;
-
- if worktree.host_connection_id == connection_id
- || worktree.share.as_ref().map_or(false, |share| {
- share.guest_connection_ids.contains_key(&connection_id)
- })
- {
- Ok(worktree)
- } else {
- Err(anyhow!(
- "{} is not a member of worktree {}",
- connection_id,
- worktree_id
- ))?
- }
- }
-
- fn add_worktree(&mut self, worktree: Worktree) -> u64 {
- let worktree_id = self.next_worktree_id;
- for collaborator_user_id in &worktree.collaborator_user_ids {
- self.visible_worktrees_by_user_id
- .entry(*collaborator_user_id)
- .or_default()
- .insert(worktree_id);
- }
- self.next_worktree_id += 1;
- if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
- connection.worktrees.insert(worktree_id);
- }
- self.worktrees.insert(worktree_id, worktree);
-
- #[cfg(test)]
- self.check_invariants();
-
- worktree_id
- }
-
- fn remove_worktree(&mut self, worktree_id: u64) {
- let worktree = self.worktrees.remove(&worktree_id).unwrap();
- if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
- connection.worktrees.remove(&worktree_id);
- }
- if let Some(share) = worktree.share {
- for connection_id in share.guest_connection_ids.keys() {
- if let Some(connection) = self.connections.get_mut(connection_id) {
- connection.worktrees.remove(&worktree_id);
- }
- }
- }
- for collaborator_user_id in worktree.collaborator_user_ids {
- if let Some(visible_worktrees) = self
- .visible_worktrees_by_user_id
- .get_mut(&collaborator_user_id)
- {
- visible_worktrees.remove(&worktree_id);
- }
- }
-
- #[cfg(test)]
- self.check_invariants();
- }
-
- #[cfg(test)]
- fn check_invariants(&self) {
- for (connection_id, connection) in &self.connections {
- for worktree_id in &connection.worktrees {
- let worktree = &self.worktrees.get(&worktree_id).unwrap();
- if worktree.host_connection_id != *connection_id {
- assert!(worktree
- .share()
- .unwrap()
- .guest_connection_ids
- .contains_key(connection_id));
- }
- }
- for channel_id in &connection.channels {
- let channel = self.channels.get(channel_id).unwrap();
- assert!(channel.connection_ids.contains(connection_id));
- }
- assert!(self
- .connections_by_user_id
- .get(&connection.user_id)
- .unwrap()
- .contains(connection_id));
- }
-
- for (user_id, connection_ids) in &self.connections_by_user_id {
- for connection_id in connection_ids {
- assert_eq!(
- self.connections.get(connection_id).unwrap().user_id,
- *user_id
- );
- }
- }
-
- for (worktree_id, worktree) in &self.worktrees {
- let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
- assert!(host_connection.worktrees.contains(worktree_id));
-
- for collaborator_id in &worktree.collaborator_user_ids {
- let visible_worktree_ids = self
- .visible_worktrees_by_user_id
- .get(collaborator_id)
- .unwrap();
- assert!(visible_worktree_ids.contains(worktree_id));
- }
-
- if let Some(share) = &worktree.share {
- for guest_connection_id in share.guest_connection_ids.keys() {
- let guest_connection = self.connections.get(guest_connection_id).unwrap();
- assert!(guest_connection.worktrees.contains(worktree_id));
- }
- assert_eq!(
- share.active_replica_ids.len(),
- share.guest_connection_ids.len(),
- );
- assert_eq!(
- share.active_replica_ids,
- share
- .guest_connection_ids
- .values()
- .copied()
- .collect::<HashSet<_>>(),
- );
- }
- }
-
- for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
- for worktree_id in visible_worktree_ids {
- let worktree = self.worktrees.get(worktree_id).unwrap();
- assert!(worktree.collaborator_user_ids.contains(user_id));
- }
- }
-
- for (channel_id, channel) in &self.channels {
- for connection_id in &channel.connection_ids {
- let connection = self.connections.get(connection_id).unwrap();
- assert!(connection.channels.contains(channel_id));
- }
- }
- }
-}
-
-impl Worktree {
- pub fn connection_ids(&self) -> Vec<ConnectionId> {
- if let Some(share) = &self.share {
- share
- .guest_connection_ids
- .keys()
- .copied()
- .chain(Some(self.host_connection_id))
- .collect()
- } else {
- vec![self.host_connection_id]
- }
- }
-
- fn share(&self) -> tide::Result<&WorktreeShare> {
- Ok(self
- .share
- .as_ref()
- .ok_or_else(|| anyhow!("worktree is not shared"))?)
- }
-
- fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
- Ok(self
- .share
- .as_mut()
- .ok_or_else(|| anyhow!("worktree is not shared"))?)
- }
-}
-
-impl Channel {
- fn connection_ids(&self) -> Vec<ConnectionId> {
- self.connection_ids.iter().copied().collect()
- }
-}
-
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let server = Server::new(app.state().clone(), rpc.clone(), None);
app.at("/rpc").with(auth::VerifyToken).get(move |request: Request<Arc<AppState>>| {
@@ -2477,16 +2069,16 @@ mod tests {
})
}
- async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {
- self.server.state.read().await
+ async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> {
+ self.server.store.read().await
}
async fn condition<F>(&mut self, mut predicate: F)
where
- F: FnMut(&ServerState) -> bool,
+ F: FnMut(&Store) -> bool,
{
async_std::future::timeout(Duration::from_millis(500), async {
- while !(predicate)(&*self.server.state.read().await) {
+ while !(predicate)(&*self.server.store.read().await) {
self.notifications.recv().await;
}
})
@@ -0,0 +1,574 @@
+use crate::db::{ChannelId, MessageId, UserId};
+use crate::errors::TideResultExt;
+use anyhow::anyhow;
+use std::collections::{hash_map, HashMap, HashSet};
+use zrpc::{proto, ConnectionId};
+
+#[derive(Default)]
+pub struct Store {
+ connections: HashMap<ConnectionId, ConnectionState>,
+ connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
+ worktrees: HashMap<u64, Worktree>,
+ visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
+ channels: HashMap<ChannelId, Channel>,
+ next_worktree_id: u64,
+}
+
+struct ConnectionState {
+ user_id: UserId,
+ worktrees: HashSet<u64>,
+ channels: HashSet<ChannelId>,
+}
+
+pub struct Worktree {
+ pub host_connection_id: ConnectionId,
+ pub collaborator_user_ids: Vec<UserId>,
+ pub root_name: String,
+ pub share: Option<WorktreeShare>,
+}
+
+struct WorktreeShare {
+ pub guest_connection_ids: HashMap<ConnectionId, ReplicaId>,
+ pub active_replica_ids: HashSet<ReplicaId>,
+ pub entries: HashMap<u64, proto::Entry>,
+}
+
+#[derive(Default)]
+struct Channel {
+ connection_ids: HashSet<ConnectionId>,
+}
+
+pub type ReplicaId = u16;
+
+#[derive(Default)]
+pub struct RemovedConnectionState {
+ pub hosted_worktrees: HashMap<u64, Worktree>,
+ pub guest_worktree_ids: HashMap<u64, Vec<ConnectionId>>,
+ pub collaborator_ids: HashSet<UserId>,
+}
+
+impl Store {
+ pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) {
+ self.connections.insert(
+ connection_id,
+ ConnectionState {
+ user_id,
+ worktrees: Default::default(),
+ channels: Default::default(),
+ },
+ );
+ self.connections_by_user_id
+ .entry(user_id)
+ .or_default()
+ .insert(connection_id);
+ }
+
+ pub fn remove_connection(
+ &mut self,
+ connection_id: ConnectionId,
+ ) -> tide::Result<RemovedConnectionState> {
+ let connection = if let Some(connection) = self.connections.get(&connection_id) {
+ connection
+ } else {
+ return Err(anyhow!("no such connection"))?;
+ };
+
+ for channel_id in connection.channels {
+ if let Some(channel) = self.channels.get_mut(&channel_id) {
+ channel.connection_ids.remove(&connection_id);
+ }
+ }
+
+ let user_connections = self
+ .connections_by_user_id
+ .get_mut(&connection.user_id)
+ .unwrap();
+ user_connections.remove(&connection_id);
+ if user_connections.is_empty() {
+ self.connections_by_user_id.remove(&connection.user_id);
+ }
+
+ let mut result = RemovedConnectionState::default();
+ for worktree_id in connection.worktrees {
+ if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) {
+ result.hosted_worktrees.insert(worktree_id, worktree);
+ result
+ .collaborator_ids
+ .extend(worktree.collaborator_user_ids.iter().copied());
+ } else {
+ if let Some(worktree) = self.worktrees.get(&worktree_id) {
+ result
+ .guest_worktree_ids
+ .insert(worktree_id, worktree.connection_ids());
+ result
+ .collaborator_ids
+ .extend(worktree.collaborator_user_ids.iter().copied());
+ }
+ }
+ }
+
+ Ok(result)
+ }
+
+ pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+ if let Some(connection) = self.connections.get_mut(&connection_id) {
+ connection.channels.insert(channel_id);
+ self.channels
+ .entry(channel_id)
+ .or_default()
+ .connection_ids
+ .insert(connection_id);
+ }
+ }
+
+ pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) {
+ if let Some(connection) = self.connections.get_mut(&connection_id) {
+ connection.channels.remove(&channel_id);
+ if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) {
+ entry.get_mut().connection_ids.remove(&connection_id);
+ if entry.get_mut().connection_ids.is_empty() {
+ entry.remove();
+ }
+ }
+ }
+ }
+
+ pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result<UserId> {
+ Ok(self
+ .connections
+ .get(&connection_id)
+ .ok_or_else(|| anyhow!("unknown connection"))?
+ .user_id)
+ }
+
+ pub fn connection_ids_for_user<'a>(
+ &'a self,
+ user_id: UserId,
+ ) -> impl 'a + Iterator<Item = ConnectionId> {
+ self.connections_by_user_id
+ .get(&user_id)
+ .into_iter()
+ .flatten()
+ .copied()
+ }
+
+ pub fn collaborators_for_user(&self, user_id: UserId) -> Vec<proto::Collaborator> {
+ let mut collaborators = HashMap::new();
+ for worktree_id in self
+ .visible_worktrees_by_user_id
+ .get(&user_id)
+ .unwrap_or(&HashSet::new())
+ {
+ let worktree = &self.worktrees[worktree_id];
+
+ let mut guests = HashSet::new();
+ if let Ok(share) = worktree.share() {
+ for guest_connection_id in share.guest_connection_ids.keys() {
+ if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) {
+ guests.insert(user_id.to_proto());
+ }
+ }
+ }
+
+ if let Ok(host_user_id) = self
+ .user_id_for_connection(worktree.host_connection_id)
+ .context("stale worktree host connection")
+ {
+ let host =
+ collaborators
+ .entry(host_user_id)
+ .or_insert_with(|| proto::Collaborator {
+ user_id: host_user_id.to_proto(),
+ worktrees: Vec::new(),
+ });
+ host.worktrees.push(proto::WorktreeMetadata {
+ root_name: worktree.root_name.clone(),
+ is_shared: worktree.share().is_ok(),
+ participants: guests.into_iter().collect(),
+ });
+ }
+ }
+
+ collaborators.into_values().collect()
+ }
+
+ pub fn add_worktree(&mut self, worktree: Worktree) -> u64 {
+ let worktree_id = self.next_worktree_id;
+ for collaborator_user_id in &worktree.collaborator_user_ids {
+ self.visible_worktrees_by_user_id
+ .entry(*collaborator_user_id)
+ .or_default()
+ .insert(worktree_id);
+ }
+ self.next_worktree_id += 1;
+ if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+ connection.worktrees.insert(worktree_id);
+ }
+ self.worktrees.insert(worktree_id, worktree);
+
+ #[cfg(test)]
+ self.check_invariants();
+
+ worktree_id
+ }
+
+ pub fn remove_worktree(
+ &mut self,
+ worktree_id: u64,
+ acting_connection_id: ConnectionId,
+ ) -> tide::Result<Worktree> {
+ let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) {
+ if e.get().host_connection_id != acting_connection_id {
+ Err(anyhow!("not your worktree"))?;
+ }
+ e.remove()
+ } else {
+ return Err(anyhow!("no such worktree"))?;
+ };
+
+ if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+ connection.worktrees.remove(&worktree_id);
+ }
+
+ if let Some(share) = worktree.share {
+ for connection_id in share.guest_connection_ids.keys() {
+ if let Some(connection) = self.connections.get_mut(connection_id) {
+ connection.worktrees.remove(&worktree_id);
+ }
+ }
+ }
+
+ for collaborator_user_id in worktree.collaborator_user_ids {
+ if let Some(visible_worktrees) = self
+ .visible_worktrees_by_user_id
+ .get_mut(&collaborator_user_id)
+ {
+ visible_worktrees.remove(&worktree_id);
+ }
+ }
+
+ #[cfg(test)]
+ self.check_invariants();
+
+ Ok(worktree)
+ }
+
+ pub fn share_worktree(
+ &mut self,
+ worktree_id: u64,
+ connection_id: ConnectionId,
+ entries: HashMap<u64, proto::Entry>,
+ ) -> Option<Vec<UserId>> {
+ if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+ if worktree.host_connection_id == connection_id {
+ worktree.share = Some(WorktreeShare {
+ guest_connection_ids: Default::default(),
+ active_replica_ids: Default::default(),
+ entries,
+ });
+ return Some(worktree.collaborator_user_ids.clone());
+ }
+ }
+ None
+ }
+
+ pub fn unshare_worktree(
+ &mut self,
+ worktree_id: u64,
+ acting_connection_id: ConnectionId,
+ ) -> tide::Result<(Vec<ConnectionId>, Vec<UserId>)> {
+ let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) {
+ worktree
+ } else {
+ return Err(anyhow!("no such worktree"))?;
+ };
+
+ if worktree.host_connection_id != acting_connection_id {
+ return Err(anyhow!("not your worktree"))?;
+ }
+
+ let connection_ids = worktree.connection_ids();
+
+ if let Some(share) = worktree.share.take() {
+ for connection_id in &connection_ids {
+ if let Some(connection) = self.connections.get_mut(connection_id) {
+ connection.worktrees.remove(&worktree_id);
+ }
+ }
+ Ok((connection_ids, worktree.collaborator_user_ids.clone()))
+ } else {
+ Err(anyhow!("worktree is not shared"))?
+ }
+ }
+
+ pub fn join_worktree(
+ &mut self,
+ connection_id: ConnectionId,
+ user_id: UserId,
+ worktree_id: u64,
+ ) -> tide::Result<(ReplicaId, &Worktree)> {
+ let connection = self
+ .connections
+ .get_mut(&connection_id)
+ .ok_or_else(|| anyhow!("no such connection"))?;
+ let worktree = self
+ .worktrees
+ .get_mut(&worktree_id)
+ .and_then(|worktree| {
+ if worktree.collaborator_user_ids.contains(&user_id) {
+ Some(worktree)
+ } else {
+ None
+ }
+ })
+ .ok_or_else(|| anyhow!("no such worktree"))?;
+
+ let share = worktree.share_mut()?;
+ connection.worktrees.insert(worktree_id);
+
+ let mut replica_id = 1;
+ while share.active_replica_ids.contains(&replica_id) {
+ replica_id += 1;
+ }
+ share.active_replica_ids.insert(replica_id);
+ share.guest_connection_ids.insert(connection_id, replica_id);
+ return Ok((replica_id, worktree));
+ }
+
+ pub fn leave_worktree(
+ &mut self,
+ connection_id: ConnectionId,
+ worktree_id: u64,
+ ) -> Option<(Vec<ConnectionId>, Vec<UserId>)> {
+ let worktree = self.worktrees.get_mut(&worktree_id)?;
+ let share = worktree.share.as_mut()?;
+ let replica_id = share.guest_connection_ids.remove(&connection_id)?;
+ share.active_replica_ids.remove(&replica_id);
+ Some((
+ worktree.connection_ids(),
+ worktree.collaborator_user_ids.clone(),
+ ))
+ }
+
+ pub fn update_worktree(
+ &mut self,
+ connection_id: ConnectionId,
+ worktree_id: u64,
+ removed_entries: &[u64],
+ updated_entries: &[proto::Entry],
+ ) -> tide::Result<Vec<ConnectionId>> {
+ let worktree = self.write_worktree(worktree_id, connection_id)?;
+ let share = worktree.share_mut()?;
+ for entry_id in removed_entries {
+ share.entries.remove(&entry_id);
+ }
+ for entry in updated_entries {
+ share.entries.insert(entry.id, entry.clone());
+ }
+ Ok(worktree.connection_ids())
+ }
+
+ pub fn worktree_host_connection_id(
+ &self,
+ connection_id: ConnectionId,
+ worktree_id: u64,
+ ) -> tide::Result<ConnectionId> {
+ Ok(self
+ .read_worktree(worktree_id, connection_id)?
+ .host_connection_id)
+ }
+
+ pub fn worktree_guest_connection_ids(
+ &self,
+ connection_id: ConnectionId,
+ worktree_id: u64,
+ ) -> tide::Result<Vec<ConnectionId>> {
+ Ok(self
+ .read_worktree(worktree_id, connection_id)?
+ .share()?
+ .guest_connection_ids
+ .keys()
+ .copied()
+ .collect())
+ }
+
+ pub fn worktree_connection_ids(
+ &self,
+ connection_id: ConnectionId,
+ worktree_id: u64,
+ ) -> tide::Result<Vec<ConnectionId>> {
+ Ok(self
+ .read_worktree(worktree_id, connection_id)?
+ .connection_ids())
+ }
+
+ pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option<Vec<ConnectionId>> {
+ Some(self.channels.get(&channel_id)?.connection_ids())
+ }
+
+ fn read_worktree(
+ &self,
+ worktree_id: u64,
+ connection_id: ConnectionId,
+ ) -> tide::Result<&Worktree> {
+ let worktree = self
+ .worktrees
+ .get(&worktree_id)
+ .ok_or_else(|| anyhow!("worktree not found"))?;
+
+ if worktree.host_connection_id == connection_id
+ || worktree
+ .share()?
+ .guest_connection_ids
+ .contains_key(&connection_id)
+ {
+ Ok(worktree)
+ } else {
+ Err(anyhow!(
+ "{} is not a member of worktree {}",
+ connection_id,
+ worktree_id
+ ))?
+ }
+ }
+
+ fn write_worktree(
+ &mut self,
+ worktree_id: u64,
+ connection_id: ConnectionId,
+ ) -> tide::Result<&mut Worktree> {
+ let worktree = self
+ .worktrees
+ .get_mut(&worktree_id)
+ .ok_or_else(|| anyhow!("worktree not found"))?;
+
+ if worktree.host_connection_id == connection_id
+ || worktree.share.as_ref().map_or(false, |share| {
+ share.guest_connection_ids.contains_key(&connection_id)
+ })
+ {
+ Ok(worktree)
+ } else {
+ Err(anyhow!(
+ "{} is not a member of worktree {}",
+ connection_id,
+ worktree_id
+ ))?
+ }
+ }
+
+ #[cfg(test)]
+ fn check_invariants(&self) {
+ for (connection_id, connection) in &self.connections {
+ for worktree_id in &connection.worktrees {
+ let worktree = &self.worktrees.get(&worktree_id).unwrap();
+ if worktree.host_connection_id != *connection_id {
+ assert!(worktree
+ .share()
+ .unwrap()
+ .guest_connection_ids
+ .contains_key(connection_id));
+ }
+ }
+ for channel_id in &connection.channels {
+ let channel = self.channels.get(channel_id).unwrap();
+ assert!(channel.connection_ids.contains(connection_id));
+ }
+ assert!(self
+ .connections_by_user_id
+ .get(&connection.user_id)
+ .unwrap()
+ .contains(connection_id));
+ }
+
+ for (user_id, connection_ids) in &self.connections_by_user_id {
+ for connection_id in connection_ids {
+ assert_eq!(
+ self.connections.get(connection_id).unwrap().user_id,
+ *user_id
+ );
+ }
+ }
+
+ for (worktree_id, worktree) in &self.worktrees {
+ let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
+ assert!(host_connection.worktrees.contains(worktree_id));
+
+ for collaborator_id in &worktree.collaborator_user_ids {
+ let visible_worktree_ids = self
+ .visible_worktrees_by_user_id
+ .get(collaborator_id)
+ .unwrap();
+ assert!(visible_worktree_ids.contains(worktree_id));
+ }
+
+ if let Some(share) = &worktree.share {
+ for guest_connection_id in share.guest_connection_ids.keys() {
+ let guest_connection = self.connections.get(guest_connection_id).unwrap();
+ assert!(guest_connection.worktrees.contains(worktree_id));
+ }
+ assert_eq!(
+ share.active_replica_ids.len(),
+ share.guest_connection_ids.len(),
+ );
+ assert_eq!(
+ share.active_replica_ids,
+ share
+ .guest_connection_ids
+ .values()
+ .copied()
+ .collect::<HashSet<_>>(),
+ );
+ }
+ }
+
+ for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
+ for worktree_id in visible_worktree_ids {
+ let worktree = self.worktrees.get(worktree_id).unwrap();
+ assert!(worktree.collaborator_user_ids.contains(user_id));
+ }
+ }
+
+ for (channel_id, channel) in &self.channels {
+ for connection_id in &channel.connection_ids {
+ let connection = self.connections.get(connection_id).unwrap();
+ assert!(connection.channels.contains(channel_id));
+ }
+ }
+ }
+}
+
+impl Worktree {
+ pub fn connection_ids(&self) -> Vec<ConnectionId> {
+ if let Some(share) = &self.share {
+ share
+ .guest_connection_ids
+ .keys()
+ .copied()
+ .chain(Some(self.host_connection_id))
+ .collect()
+ } else {
+ vec![self.host_connection_id]
+ }
+ }
+
+ pub fn share(&self) -> tide::Result<&WorktreeShare> {
+ Ok(self
+ .share
+ .as_ref()
+ .ok_or_else(|| anyhow!("worktree is not shared"))?)
+ }
+
+ fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> {
+ Ok(self
+ .share
+ .as_mut()
+ .ok_or_else(|| anyhow!("worktree is not shared"))?)
+ }
+}
+
+impl Channel {
+ fn connection_ids(&self) -> Vec<ConnectionId> {
+ self.connection_ids.iter().copied().collect()
+ }
+}