1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
/// In-memory registry of live RPC connections, indexed several ways.
///
/// `connections` is the primary table. `connected_users` and
/// `connected_dev_servers` are secondary indexes into it, which the test-only
/// `check_invariants` verifies agree with `connections`.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every live connection, keyed by its connection id.
    connections: BTreeMap<ConnectionId, Connection>,
    /// For each user with at least one registered connection, the set of their connection ids.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    /// The connection id most recently registered for each dev server.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    /// Channel subscriptions, indexed both by user and by channel.
    channels: ChannelPool,
    /// Dev servers marked offline via `set_dev_server_offline`; `dev_server_status`
    /// reports these as offline even while a connection is registered.
    offline_dev_servers: HashSet<DevServerId>,
}
18
/// The set of registered connections belonging to a single user.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    /// Ids of this user's connections; each one is also a key in `ConnectionPool::connections`.
    connection_ids: HashSet<ConnectionId>,
}
23
/// The Zed version recorded for a connection (stored on `Connection::zed_version`).
#[derive(Debug, Serialize)]
pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 129, 2)
36 }
37}
38
/// Metadata stored for one registered connection.
#[derive(Serialize)]
pub struct Connection {
    /// The principal (a user or a dev server) that owns this connection.
    pub principal_id: PrincipalId,
    /// Admin flag supplied at registration; `add_dev_server` always sets it to `false`.
    pub admin: bool,
    /// The Zed version supplied when the connection was registered.
    pub zed_version: ZedVersion,
}
45
46impl ConnectionPool {
47 pub fn reset(&mut self) {
48 self.connections.clear();
49 self.connected_users.clear();
50 self.channels.clear();
51 }
52
53 #[instrument(skip(self))]
54 pub fn add_connection(
55 &mut self,
56 connection_id: ConnectionId,
57 user_id: UserId,
58 admin: bool,
59 zed_version: ZedVersion,
60 ) {
61 self.connections.insert(
62 connection_id,
63 Connection {
64 principal_id: PrincipalId::UserId(user_id),
65 admin,
66 zed_version,
67 },
68 );
69 let connected_user = self.connected_users.entry(user_id).or_default();
70 connected_user.connection_ids.insert(connection_id);
71 }
72
73 pub fn add_dev_server(
74 &mut self,
75 connection_id: ConnectionId,
76 dev_server_id: DevServerId,
77 zed_version: ZedVersion,
78 ) {
79 self.connections.insert(
80 connection_id,
81 Connection {
82 principal_id: PrincipalId::DevServerId(dev_server_id),
83 admin: false,
84 zed_version,
85 },
86 );
87
88 self.connected_dev_servers
89 .insert(dev_server_id, connection_id);
90 }
91
92 #[instrument(skip(self))]
93 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
94 let connection = self
95 .connections
96 .get_mut(&connection_id)
97 .ok_or_else(|| anyhow!("no such connection"))?;
98
99 match connection.principal_id {
100 PrincipalId::UserId(user_id) => {
101 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
102 connected_user.connection_ids.remove(&connection_id);
103 if connected_user.connection_ids.is_empty() {
104 self.connected_users.remove(&user_id);
105 self.channels.remove_user(&user_id);
106 }
107 }
108 PrincipalId::DevServerId(dev_server_id) => {
109 self.connected_dev_servers.remove(&dev_server_id);
110 self.offline_dev_servers.remove(&dev_server_id);
111 }
112 }
113 self.connections.remove(&connection_id).unwrap();
114 Ok(())
115 }
116
117 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
118 self.offline_dev_servers.insert(dev_server_id);
119 }
120
121 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
122 self.connections.values()
123 }
124
125 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
126 self.connected_users
127 .get(&user_id)
128 .into_iter()
129 .flat_map(|state| {
130 state
131 .connection_ids
132 .iter()
133 .flat_map(|cid| self.connections.get(cid))
134 })
135 }
136
137 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
138 self.connected_users
139 .get(&user_id)
140 .into_iter()
141 .flat_map(|state| &state.connection_ids)
142 .copied()
143 }
144
145 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
146 if self.dev_server_connection_id(dev_server_id).is_some()
147 && !self.offline_dev_servers.contains(&dev_server_id)
148 {
149 proto::DevServerStatus::Online
150 } else {
151 proto::DevServerStatus::Offline
152 }
153 }
154
155 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
156 self.connected_dev_servers.get(&dev_server_id).copied()
157 }
158
159 pub fn channel_user_ids(
160 &self,
161 channel_id: ChannelId,
162 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
163 self.channels.users_to_notify(channel_id)
164 }
165
166 pub fn channel_connection_ids(
167 &self,
168 channel_id: ChannelId,
169 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
170 self.channels
171 .users_to_notify(channel_id)
172 .flat_map(|(user_id, role)| {
173 self.user_connection_ids(user_id)
174 .map(move |connection_id| (connection_id, role))
175 })
176 }
177
178 pub fn subscribe_to_channel(
179 &mut self,
180 user_id: UserId,
181 channel_id: ChannelId,
182 role: ChannelRole,
183 ) {
184 self.channels.subscribe(user_id, channel_id, role);
185 }
186
187 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
188 self.channels.unsubscribe(user_id, channel_id);
189 }
190
191 pub fn is_user_online(&self, user_id: UserId) -> bool {
192 !self
193 .connected_users
194 .get(&user_id)
195 .unwrap_or(&Default::default())
196 .connection_ids
197 .is_empty()
198 }
199
200 #[cfg(test)]
201 pub fn check_invariants(&self) {
202 for (connection_id, connection) in &self.connections {
203 match &connection.principal_id {
204 PrincipalId::UserId(user_id) => {
205 assert!(self
206 .connected_users
207 .get(user_id)
208 .unwrap()
209 .connection_ids
210 .contains(connection_id));
211 }
212 PrincipalId::DevServerId(dev_server_id) => {
213 assert_eq!(
214 self.connected_dev_servers.get(&dev_server_id).unwrap(),
215 connection_id
216 );
217 }
218 }
219 }
220
221 for (user_id, state) in &self.connected_users {
222 for connection_id in &state.connection_ids {
223 assert_eq!(
224 self.connections.get(connection_id).unwrap().principal_id,
225 PrincipalId::UserId(*user_id)
226 );
227 }
228 }
229
230 for (dev_server_id, connection_id) in &self.connected_dev_servers {
231 assert_eq!(
232 self.connections.get(connection_id).unwrap().principal_id,
233 PrincipalId::DevServerId(*dev_server_id)
234 );
235 }
236 }
237}
238
/// Channel subscriptions indexed in both directions.
///
/// `ChannelPool`'s methods keep the two maps consistent, and `unsubscribe`
/// drops entries that become empty.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// Each user's subscribed channels and the user's role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// Each channel's subscribed users.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
244
245impl ChannelPool {
246 pub fn clear(&mut self) {
247 self.by_user.clear();
248 self.by_channel.clear();
249 }
250
251 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
252 self.by_user
253 .entry(user_id)
254 .or_default()
255 .insert(channel_id, role);
256 self.by_channel
257 .entry(channel_id)
258 .or_default()
259 .insert(user_id);
260 }
261
262 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
263 if let Some(channels) = self.by_user.get_mut(user_id) {
264 channels.remove(channel_id);
265 if channels.is_empty() {
266 self.by_user.remove(user_id);
267 }
268 }
269 if let Some(users) = self.by_channel.get_mut(channel_id) {
270 users.remove(user_id);
271 if users.is_empty() {
272 self.by_channel.remove(channel_id);
273 }
274 }
275 }
276
277 pub fn remove_user(&mut self, user_id: &UserId) {
278 if let Some(channels) = self.by_user.remove(&user_id) {
279 for channel_id in channels.keys() {
280 self.unsubscribe(user_id, &channel_id)
281 }
282 }
283 }
284
285 pub fn users_to_notify(
286 &self,
287 channel_id: ChannelId,
288 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
289 self.by_channel
290 .get(&channel_id)
291 .into_iter()
292 .flat_map(move |users| {
293 users.iter().flat_map(move |user_id| {
294 Some((
295 *user_id,
296 self.by_user
297 .get(user_id)
298 .and_then(|channels| channels.get(&channel_id))
299 .copied()?,
300 ))
301 })
302 })
303 }
304}