@@ -3,6 +3,7 @@ use super::{
db::{ChannelId, MessageId, UserId},
AppState,
};
+use crate::errors::TideResultExt;
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
@@ -49,7 +50,7 @@ pub struct Server {
struct ServerState {
connections: HashMap<ConnectionId, ConnectionState>,
connections_by_user_id: HashMap<UserId, HashSet<ConnectionId>>,
- pub worktrees: HashMap<u64, Worktree>,
+ worktrees: HashMap<u64, Worktree>,
visible_worktrees_by_user_id: HashMap<UserId, HashSet<u64>>,
channels: HashMap<ChannelId, Channel>,
next_worktree_id: u64,
@@ -707,15 +708,19 @@ impl Server {
{
let worktree = &state.worktrees[worktree_id];
- let mut participants = HashSet::new();
+ let mut guests = HashSet::new();
if let Ok(share) = worktree.share() {
for guest_connection_id in share.guest_connection_ids.keys() {
- let user_id = state.user_id_for_connection(*guest_connection_id)?;
- participants.insert(user_id.to_proto());
+ let user_id = state
+ .user_id_for_connection(*guest_connection_id)
+ .context("stale worktree guest connection")?;
+ guests.insert(user_id.to_proto());
}
}
- let host_user_id = state.user_id_for_connection(worktree.host_connection_id)?;
+ let host_user_id = state
+ .user_id_for_connection(worktree.host_connection_id)
+ .context("stale worktree host connection")?;
let host =
collaborators
.entry(host_user_id)
@@ -726,7 +731,7 @@ impl Server {
host.worktrees.push(proto::WorktreeMetadata {
root_name: worktree.root_name.clone(),
is_shared: worktree.share().is_ok(),
- participants: participants.into_iter().collect(),
+ participants: guests.into_iter().collect(),
});
}
@@ -1137,7 +1142,14 @@ impl ServerState {
.insert(worktree_id);
}
self.next_worktree_id += 1;
+ if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) {
+ connection.worktrees.insert(worktree_id);
+ }
self.worktrees.insert(worktree_id, worktree);
+
+ #[cfg(test)]
+ self.check_invariants();
+
worktree_id
}
@@ -1161,6 +1173,89 @@ impl ServerState {
visible_worktrees.remove(&worktree_id);
}
}
+
+ #[cfg(test)]
+ self.check_invariants();
+ }
+
+ #[cfg(test)]
+ fn check_invariants(&self) {
+ for (connection_id, connection) in &self.connections {
+ for worktree_id in &connection.worktrees {
+ let worktree = &self.worktrees.get(&worktree_id).unwrap();
+ if worktree.host_connection_id != *connection_id {
+ assert!(worktree
+ .share()
+ .unwrap()
+ .guest_connection_ids
+ .contains_key(connection_id));
+ }
+ }
+ for channel_id in &connection.channels {
+ let channel = self.channels.get(channel_id).unwrap();
+ assert!(channel.connection_ids.contains(connection_id));
+ }
+ assert!(self
+ .connections_by_user_id
+ .get(&connection.user_id)
+ .unwrap()
+ .contains(connection_id));
+ }
+
+ for (user_id, connection_ids) in &self.connections_by_user_id {
+ for connection_id in connection_ids {
+ assert_eq!(
+ self.connections.get(connection_id).unwrap().user_id,
+ *user_id
+ );
+ }
+ }
+
+ for (worktree_id, worktree) in &self.worktrees {
+ let host_connection = self.connections.get(&worktree.host_connection_id).unwrap();
+ assert!(host_connection.worktrees.contains(worktree_id));
+
+ for collaborator_id in &worktree.collaborator_user_ids {
+ let visible_worktree_ids = self
+ .visible_worktrees_by_user_id
+ .get(collaborator_id)
+ .unwrap();
+ assert!(visible_worktree_ids.contains(worktree_id));
+ }
+
+ if let Some(share) = &worktree.share {
+ for guest_connection_id in share.guest_connection_ids.keys() {
+ let guest_connection = self.connections.get(guest_connection_id).unwrap();
+ assert!(guest_connection.worktrees.contains(worktree_id));
+ }
+ assert_eq!(
+ share.active_replica_ids.len(),
+ share.guest_connection_ids.len(),
+ );
+ assert_eq!(
+ share.active_replica_ids,
+ share
+ .guest_connection_ids
+ .values()
+ .copied()
+ .collect::<HashSet<_>>(),
+ );
+ }
+ }
+
+ for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id {
+ for worktree_id in visible_worktree_ids {
+ let worktree = self.worktrees.get(worktree_id).unwrap();
+ assert!(worktree.collaborator_user_ids.contains(user_id));
+ }
+ }
+
+ for (channel_id, channel) in &self.channels {
+ for connection_id in &channel.connection_ids {
+ let connection = self.connections.get(connection_id).unwrap();
+ assert!(connection.channels.contains(channel_id));
+ }
+ }
}
}