1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
15 channels: ChannelPool,
16 offline_dev_servers: HashSet<DevServerId>,
17}
18
19#[derive(Default, Serialize)]
20struct ConnectedPrincipal {
21 connection_ids: HashSet<ConnectionId>,
22}
23
24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
25pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 134, 0)
36 }
37
38 pub fn with_list_directory() -> ZedVersion {
39 ZedVersion(SemanticVersion::new(0, 145, 0))
40 }
41
42 pub fn with_search_candidates() -> ZedVersion {
43 ZedVersion(SemanticVersion::new(0, 151, 0))
44 }
45}
46
47#[derive(Serialize)]
48pub struct Connection {
49 pub principal_id: PrincipalId,
50 pub admin: bool,
51 pub zed_version: ZedVersion,
52}
53
54impl ConnectionPool {
55 pub fn reset(&mut self) {
56 self.connections.clear();
57 self.connected_users.clear();
58 self.connected_dev_servers.clear();
59 self.channels.clear();
60 }
61
62 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
63 self.connections.get(&connection_id)
64 }
65
66 #[instrument(skip(self))]
67 pub fn add_connection(
68 &mut self,
69 connection_id: ConnectionId,
70 user_id: UserId,
71 admin: bool,
72 zed_version: ZedVersion,
73 ) {
74 self.connections.insert(
75 connection_id,
76 Connection {
77 principal_id: PrincipalId::UserId(user_id),
78 admin,
79 zed_version,
80 },
81 );
82 let connected_user = self.connected_users.entry(user_id).or_default();
83 connected_user.connection_ids.insert(connection_id);
84 }
85
86 pub fn add_dev_server(
87 &mut self,
88 connection_id: ConnectionId,
89 dev_server_id: DevServerId,
90 zed_version: ZedVersion,
91 ) {
92 self.connections.insert(
93 connection_id,
94 Connection {
95 principal_id: PrincipalId::DevServerId(dev_server_id),
96 admin: false,
97 zed_version,
98 },
99 );
100
101 self.connected_dev_servers
102 .insert(dev_server_id, connection_id);
103 }
104
105 #[instrument(skip(self))]
106 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
107 let connection = self
108 .connections
109 .get_mut(&connection_id)
110 .ok_or_else(|| anyhow!("no such connection"))?;
111
112 match connection.principal_id {
113 PrincipalId::UserId(user_id) => {
114 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
115 connected_user.connection_ids.remove(&connection_id);
116 if connected_user.connection_ids.is_empty() {
117 self.connected_users.remove(&user_id);
118 self.channels.remove_user(&user_id);
119 }
120 }
121 PrincipalId::DevServerId(dev_server_id) => {
122 self.connected_dev_servers.remove(&dev_server_id);
123 self.offline_dev_servers.remove(&dev_server_id);
124 }
125 }
126 self.connections.remove(&connection_id).unwrap();
127 Ok(())
128 }
129
130 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
131 self.offline_dev_servers.insert(dev_server_id);
132 }
133
134 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
135 self.connections.values()
136 }
137
138 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
139 self.connected_users
140 .get(&user_id)
141 .into_iter()
142 .flat_map(|state| {
143 state
144 .connection_ids
145 .iter()
146 .flat_map(|cid| self.connections.get(cid))
147 })
148 }
149
150 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
151 self.connected_users
152 .get(&user_id)
153 .into_iter()
154 .flat_map(|state| &state.connection_ids)
155 .copied()
156 }
157
158 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
159 if self.dev_server_connection_id(dev_server_id).is_some()
160 && !self.offline_dev_servers.contains(&dev_server_id)
161 {
162 proto::DevServerStatus::Online
163 } else {
164 proto::DevServerStatus::Offline
165 }
166 }
167
168 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
169 self.connected_dev_servers.get(&dev_server_id).copied()
170 }
171
172 pub fn dev_server_connection_id_supporting(
173 &self,
174 dev_server_id: DevServerId,
175 required: ZedVersion,
176 ) -> Result<ConnectionId> {
177 match self.connected_dev_servers.get(&dev_server_id) {
178 Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
179 Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
180 None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
181 }
182 }
183
184 pub fn channel_user_ids(
185 &self,
186 channel_id: ChannelId,
187 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
188 self.channels.users_to_notify(channel_id)
189 }
190
191 pub fn channel_connection_ids(
192 &self,
193 channel_id: ChannelId,
194 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
195 self.channels
196 .users_to_notify(channel_id)
197 .flat_map(|(user_id, role)| {
198 self.user_connection_ids(user_id)
199 .map(move |connection_id| (connection_id, role))
200 })
201 }
202
203 pub fn subscribe_to_channel(
204 &mut self,
205 user_id: UserId,
206 channel_id: ChannelId,
207 role: ChannelRole,
208 ) {
209 self.channels.subscribe(user_id, channel_id, role);
210 }
211
212 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
213 self.channels.unsubscribe(user_id, channel_id);
214 }
215
216 pub fn is_user_online(&self, user_id: UserId) -> bool {
217 !self
218 .connected_users
219 .get(&user_id)
220 .unwrap_or(&Default::default())
221 .connection_ids
222 .is_empty()
223 }
224
225 #[cfg(test)]
226 pub fn check_invariants(&self) {
227 for (connection_id, connection) in &self.connections {
228 match &connection.principal_id {
229 PrincipalId::UserId(user_id) => {
230 assert!(self
231 .connected_users
232 .get(user_id)
233 .unwrap()
234 .connection_ids
235 .contains(connection_id));
236 }
237 PrincipalId::DevServerId(dev_server_id) => {
238 assert_eq!(
239 self.connected_dev_servers.get(dev_server_id).unwrap(),
240 connection_id
241 );
242 }
243 }
244 }
245
246 for (user_id, state) in &self.connected_users {
247 for connection_id in &state.connection_ids {
248 assert_eq!(
249 self.connections.get(connection_id).unwrap().principal_id,
250 PrincipalId::UserId(*user_id)
251 );
252 }
253 }
254
255 for (dev_server_id, connection_id) in &self.connected_dev_servers {
256 assert_eq!(
257 self.connections.get(connection_id).unwrap().principal_id,
258 PrincipalId::DevServerId(*dev_server_id)
259 );
260 }
261 }
262}
263
264#[derive(Default, Serialize)]
265pub struct ChannelPool {
266 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
267 by_channel: HashMap<ChannelId, HashSet<UserId>>,
268}
269
270impl ChannelPool {
271 pub fn clear(&mut self) {
272 self.by_user.clear();
273 self.by_channel.clear();
274 }
275
276 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
277 self.by_user
278 .entry(user_id)
279 .or_default()
280 .insert(channel_id, role);
281 self.by_channel
282 .entry(channel_id)
283 .or_default()
284 .insert(user_id);
285 }
286
287 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
288 if let Some(channels) = self.by_user.get_mut(user_id) {
289 channels.remove(channel_id);
290 if channels.is_empty() {
291 self.by_user.remove(user_id);
292 }
293 }
294 if let Some(users) = self.by_channel.get_mut(channel_id) {
295 users.remove(user_id);
296 if users.is_empty() {
297 self.by_channel.remove(channel_id);
298 }
299 }
300 }
301
302 pub fn remove_user(&mut self, user_id: &UserId) {
303 if let Some(channels) = self.by_user.remove(user_id) {
304 for channel_id in channels.keys() {
305 self.unsubscribe(user_id, channel_id)
306 }
307 }
308 }
309
310 pub fn users_to_notify(
311 &self,
312 channel_id: ChannelId,
313 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
314 self.by_channel
315 .get(&channel_id)
316 .into_iter()
317 .flat_map(move |users| {
318 users.iter().flat_map(move |user_id| {
319 Some((
320 *user_id,
321 self.by_user
322 .get(user_id)
323 .and_then(|channels| channels.get(&channel_id))
324 .copied()?,
325 ))
326 })
327 })
328 }
329}