1use crate::db::{ChannelId, ChannelRole, UserId};
2use anyhow::{Context as _, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::ConnectionId;
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 channels: ChannelPool,
15}
16
17#[derive(Default, Serialize)]
18struct ConnectedPrincipal {
19 connection_ids: HashSet<ConnectionId>,
20}
21
22#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
23pub struct ZedVersion(pub SemanticVersion);
24
25impl fmt::Display for ZedVersion {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl ZedVersion {
32 pub fn can_collaborate(&self) -> bool {
33 self.0 >= SemanticVersion::new(0, 157, 0)
34 }
35}
36
37#[derive(Serialize)]
38pub struct Connection {
39 pub user_id: UserId,
40 pub admin: bool,
41 pub zed_version: ZedVersion,
42}
43
44impl ConnectionPool {
45 pub fn reset(&mut self) {
46 self.connections.clear();
47 self.connected_users.clear();
48 self.channels.clear();
49 }
50
51 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
52 self.connections.get(&connection_id)
53 }
54
55 #[instrument(skip(self))]
56 pub fn add_connection(
57 &mut self,
58 connection_id: ConnectionId,
59 user_id: UserId,
60 admin: bool,
61 zed_version: ZedVersion,
62 ) {
63 self.connections.insert(
64 connection_id,
65 Connection {
66 user_id,
67 admin,
68 zed_version,
69 },
70 );
71 let connected_user = self.connected_users.entry(user_id).or_default();
72 connected_user.connection_ids.insert(connection_id);
73 }
74
75 #[instrument(skip(self))]
76 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
77 let connection = self
78 .connections
79 .get_mut(&connection_id)
80 .context("no such connection")?;
81
82 let user_id = connection.user_id;
83
84 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
85 connected_user.connection_ids.remove(&connection_id);
86 if connected_user.connection_ids.is_empty() {
87 self.connected_users.remove(&user_id);
88 self.channels.remove_user(&user_id);
89 };
90 self.connections.remove(&connection_id).unwrap();
91 Ok(())
92 }
93
94 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
95 self.connections.values()
96 }
97
98 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
99 self.connected_users
100 .get(&user_id)
101 .into_iter()
102 .flat_map(|state| {
103 state
104 .connection_ids
105 .iter()
106 .flat_map(|cid| self.connections.get(cid))
107 })
108 }
109
110 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
111 self.connected_users
112 .get(&user_id)
113 .into_iter()
114 .flat_map(|state| &state.connection_ids)
115 .copied()
116 }
117
118 pub fn channel_user_ids(
119 &self,
120 channel_id: ChannelId,
121 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
122 self.channels.users_to_notify(channel_id)
123 }
124
125 pub fn channel_connection_ids(
126 &self,
127 channel_id: ChannelId,
128 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
129 self.channels
130 .users_to_notify(channel_id)
131 .flat_map(|(user_id, role)| {
132 self.user_connection_ids(user_id)
133 .map(move |connection_id| (connection_id, role))
134 })
135 }
136
137 pub fn subscribe_to_channel(
138 &mut self,
139 user_id: UserId,
140 channel_id: ChannelId,
141 role: ChannelRole,
142 ) {
143 self.channels.subscribe(user_id, channel_id, role);
144 }
145
146 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
147 self.channels.unsubscribe(user_id, channel_id);
148 }
149
150 pub fn is_user_online(&self, user_id: UserId) -> bool {
151 !self
152 .connected_users
153 .get(&user_id)
154 .unwrap_or(&Default::default())
155 .connection_ids
156 .is_empty()
157 }
158
159 #[cfg(test)]
160 pub fn check_invariants(&self) {
161 for (connection_id, connection) in &self.connections {
162 assert!(
163 self.connected_users
164 .get(&connection.user_id)
165 .unwrap()
166 .connection_ids
167 .contains(connection_id)
168 );
169 }
170
171 for (user_id, state) in &self.connected_users {
172 for connection_id in &state.connection_ids {
173 assert_eq!(
174 self.connections.get(connection_id).unwrap().user_id,
175 *user_id
176 );
177 }
178 }
179 }
180}
181
182#[derive(Default, Serialize)]
183pub struct ChannelPool {
184 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
185 by_channel: HashMap<ChannelId, HashSet<UserId>>,
186}
187
188impl ChannelPool {
189 pub fn clear(&mut self) {
190 self.by_user.clear();
191 self.by_channel.clear();
192 }
193
194 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
195 self.by_user
196 .entry(user_id)
197 .or_default()
198 .insert(channel_id, role);
199 self.by_channel
200 .entry(channel_id)
201 .or_default()
202 .insert(user_id);
203 }
204
205 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
206 if let Some(channels) = self.by_user.get_mut(user_id) {
207 channels.remove(channel_id);
208 if channels.is_empty() {
209 self.by_user.remove(user_id);
210 }
211 }
212 if let Some(users) = self.by_channel.get_mut(channel_id) {
213 users.remove(user_id);
214 if users.is_empty() {
215 self.by_channel.remove(channel_id);
216 }
217 }
218 }
219
220 pub fn remove_user(&mut self, user_id: &UserId) {
221 if let Some(channels) = self.by_user.remove(user_id) {
222 for channel_id in channels.keys() {
223 self.unsubscribe(user_id, channel_id)
224 }
225 }
226 }
227
228 pub fn users_to_notify(
229 &self,
230 channel_id: ChannelId,
231 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
232 self.by_channel
233 .get(&channel_id)
234 .into_iter()
235 .flat_map(move |users| {
236 users.iter().flat_map(move |user_id| {
237 Some((
238 *user_id,
239 self.by_user
240 .get(user_id)
241 .and_then(|channels| channels.get(&channel_id))
242 .copied()?,
243 ))
244 })
245 })
246 }
247}