1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
15 channels: ChannelPool,
16 offline_dev_servers: HashSet<DevServerId>,
17}
18
19#[derive(Default, Serialize)]
20struct ConnectedPrincipal {
21 connection_ids: HashSet<ConnectionId>,
22}
23
24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
25pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 151, 0)
36 }
37}
38
39#[derive(Serialize)]
40pub struct Connection {
41 pub principal_id: PrincipalId,
42 pub admin: bool,
43 pub zed_version: ZedVersion,
44}
45
46impl ConnectionPool {
47 pub fn reset(&mut self) {
48 self.connections.clear();
49 self.connected_users.clear();
50 self.connected_dev_servers.clear();
51 self.channels.clear();
52 }
53
54 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
55 self.connections.get(&connection_id)
56 }
57
58 #[instrument(skip(self))]
59 pub fn add_connection(
60 &mut self,
61 connection_id: ConnectionId,
62 user_id: UserId,
63 admin: bool,
64 zed_version: ZedVersion,
65 ) {
66 self.connections.insert(
67 connection_id,
68 Connection {
69 principal_id: PrincipalId::UserId(user_id),
70 admin,
71 zed_version,
72 },
73 );
74 let connected_user = self.connected_users.entry(user_id).or_default();
75 connected_user.connection_ids.insert(connection_id);
76 }
77
78 pub fn add_dev_server(
79 &mut self,
80 connection_id: ConnectionId,
81 dev_server_id: DevServerId,
82 zed_version: ZedVersion,
83 ) {
84 self.connections.insert(
85 connection_id,
86 Connection {
87 principal_id: PrincipalId::DevServerId(dev_server_id),
88 admin: false,
89 zed_version,
90 },
91 );
92
93 self.connected_dev_servers
94 .insert(dev_server_id, connection_id);
95 }
96
97 #[instrument(skip(self))]
98 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
99 let connection = self
100 .connections
101 .get_mut(&connection_id)
102 .ok_or_else(|| anyhow!("no such connection"))?;
103
104 match connection.principal_id {
105 PrincipalId::UserId(user_id) => {
106 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
107 connected_user.connection_ids.remove(&connection_id);
108 if connected_user.connection_ids.is_empty() {
109 self.connected_users.remove(&user_id);
110 self.channels.remove_user(&user_id);
111 }
112 }
113 PrincipalId::DevServerId(dev_server_id) => {
114 self.connected_dev_servers.remove(&dev_server_id);
115 self.offline_dev_servers.remove(&dev_server_id);
116 }
117 }
118 self.connections.remove(&connection_id).unwrap();
119 Ok(())
120 }
121
122 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
123 self.offline_dev_servers.insert(dev_server_id);
124 }
125
126 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
127 self.connections.values()
128 }
129
130 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
131 self.connected_users
132 .get(&user_id)
133 .into_iter()
134 .flat_map(|state| {
135 state
136 .connection_ids
137 .iter()
138 .flat_map(|cid| self.connections.get(cid))
139 })
140 }
141
142 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
143 self.connected_users
144 .get(&user_id)
145 .into_iter()
146 .flat_map(|state| &state.connection_ids)
147 .copied()
148 }
149
150 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
151 if self.dev_server_connection_id(dev_server_id).is_some()
152 && !self.offline_dev_servers.contains(&dev_server_id)
153 {
154 proto::DevServerStatus::Online
155 } else {
156 proto::DevServerStatus::Offline
157 }
158 }
159
160 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
161 self.connected_dev_servers.get(&dev_server_id).copied()
162 }
163
164 pub fn online_dev_server_connection_id(
165 &self,
166 dev_server_id: DevServerId,
167 ) -> Result<ConnectionId> {
168 match self.connected_dev_servers.get(&dev_server_id) {
169 Some(cid) => Ok(*cid),
170 None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
171 }
172 }
173
174 pub fn dev_server_connection_id_supporting(
175 &self,
176 dev_server_id: DevServerId,
177 required: ZedVersion,
178 ) -> Result<ConnectionId> {
179 match self.connected_dev_servers.get(&dev_server_id) {
180 Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
181 Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
182 None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
183 }
184 }
185
186 pub fn channel_user_ids(
187 &self,
188 channel_id: ChannelId,
189 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
190 self.channels.users_to_notify(channel_id)
191 }
192
193 pub fn channel_connection_ids(
194 &self,
195 channel_id: ChannelId,
196 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
197 self.channels
198 .users_to_notify(channel_id)
199 .flat_map(|(user_id, role)| {
200 self.user_connection_ids(user_id)
201 .map(move |connection_id| (connection_id, role))
202 })
203 }
204
205 pub fn subscribe_to_channel(
206 &mut self,
207 user_id: UserId,
208 channel_id: ChannelId,
209 role: ChannelRole,
210 ) {
211 self.channels.subscribe(user_id, channel_id, role);
212 }
213
214 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
215 self.channels.unsubscribe(user_id, channel_id);
216 }
217
218 pub fn is_user_online(&self, user_id: UserId) -> bool {
219 !self
220 .connected_users
221 .get(&user_id)
222 .unwrap_or(&Default::default())
223 .connection_ids
224 .is_empty()
225 }
226
227 #[cfg(test)]
228 pub fn check_invariants(&self) {
229 for (connection_id, connection) in &self.connections {
230 match &connection.principal_id {
231 PrincipalId::UserId(user_id) => {
232 assert!(self
233 .connected_users
234 .get(user_id)
235 .unwrap()
236 .connection_ids
237 .contains(connection_id));
238 }
239 PrincipalId::DevServerId(dev_server_id) => {
240 assert_eq!(
241 self.connected_dev_servers.get(dev_server_id).unwrap(),
242 connection_id
243 );
244 }
245 }
246 }
247
248 for (user_id, state) in &self.connected_users {
249 for connection_id in &state.connection_ids {
250 assert_eq!(
251 self.connections.get(connection_id).unwrap().principal_id,
252 PrincipalId::UserId(*user_id)
253 );
254 }
255 }
256
257 for (dev_server_id, connection_id) in &self.connected_dev_servers {
258 assert_eq!(
259 self.connections.get(connection_id).unwrap().principal_id,
260 PrincipalId::DevServerId(*dev_server_id)
261 );
262 }
263 }
264}
265
266#[derive(Default, Serialize)]
267pub struct ChannelPool {
268 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
269 by_channel: HashMap<ChannelId, HashSet<UserId>>,
270}
271
272impl ChannelPool {
273 pub fn clear(&mut self) {
274 self.by_user.clear();
275 self.by_channel.clear();
276 }
277
278 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
279 self.by_user
280 .entry(user_id)
281 .or_default()
282 .insert(channel_id, role);
283 self.by_channel
284 .entry(channel_id)
285 .or_default()
286 .insert(user_id);
287 }
288
289 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
290 if let Some(channels) = self.by_user.get_mut(user_id) {
291 channels.remove(channel_id);
292 if channels.is_empty() {
293 self.by_user.remove(user_id);
294 }
295 }
296 if let Some(users) = self.by_channel.get_mut(channel_id) {
297 users.remove(user_id);
298 if users.is_empty() {
299 self.by_channel.remove(channel_id);
300 }
301 }
302 }
303
304 pub fn remove_user(&mut self, user_id: &UserId) {
305 if let Some(channels) = self.by_user.remove(user_id) {
306 for channel_id in channels.keys() {
307 self.unsubscribe(user_id, channel_id)
308 }
309 }
310 }
311
312 pub fn users_to_notify(
313 &self,
314 channel_id: ChannelId,
315 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
316 self.by_channel
317 .get(&channel_id)
318 .into_iter()
319 .flat_map(move |users| {
320 users.iter().flat_map(move |user_id| {
321 Some((
322 *user_id,
323 self.by_user
324 .get(user_id)
325 .and_then(|channels| channels.get(&channel_id))
326 .copied()?,
327 ))
328 })
329 })
330 }
331}