1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
15 channels: ChannelPool,
16 offline_dev_servers: HashSet<DevServerId>,
17}
18
19#[derive(Default, Serialize)]
20struct ConnectedPrincipal {
21 connection_ids: HashSet<ConnectionId>,
22}
23
24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
25pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 129, 2)
36 }
37
38 pub fn with_save_as() -> ZedVersion {
39 ZedVersion(SemanticVersion::new(0, 134, 0))
40 }
41
42 pub fn with_list_directory() -> ZedVersion {
43 ZedVersion(SemanticVersion::new(0, 145, 0))
44 }
45}
46
47pub trait VersionedMessage {
48 fn required_host_version(&self) -> Option<ZedVersion> {
49 None
50 }
51}
52
53impl VersionedMessage for proto::SaveBuffer {
54 fn required_host_version(&self) -> Option<ZedVersion> {
55 if self.new_path.is_some() {
56 Some(ZedVersion::with_save_as())
57 } else {
58 None
59 }
60 }
61}
62
63impl VersionedMessage for proto::OpenNewBuffer {
64 fn required_host_version(&self) -> Option<ZedVersion> {
65 Some(ZedVersion::with_save_as())
66 }
67}
68
69#[derive(Serialize)]
70pub struct Connection {
71 pub principal_id: PrincipalId,
72 pub admin: bool,
73 pub zed_version: ZedVersion,
74}
75
76impl ConnectionPool {
77 pub fn reset(&mut self) {
78 self.connections.clear();
79 self.connected_users.clear();
80 self.connected_dev_servers.clear();
81 self.channels.clear();
82 }
83
84 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
85 self.connections.get(&connection_id)
86 }
87
88 #[instrument(skip(self))]
89 pub fn add_connection(
90 &mut self,
91 connection_id: ConnectionId,
92 user_id: UserId,
93 admin: bool,
94 zed_version: ZedVersion,
95 ) {
96 self.connections.insert(
97 connection_id,
98 Connection {
99 principal_id: PrincipalId::UserId(user_id),
100 admin,
101 zed_version,
102 },
103 );
104 let connected_user = self.connected_users.entry(user_id).or_default();
105 connected_user.connection_ids.insert(connection_id);
106 }
107
108 pub fn add_dev_server(
109 &mut self,
110 connection_id: ConnectionId,
111 dev_server_id: DevServerId,
112 zed_version: ZedVersion,
113 ) {
114 self.connections.insert(
115 connection_id,
116 Connection {
117 principal_id: PrincipalId::DevServerId(dev_server_id),
118 admin: false,
119 zed_version,
120 },
121 );
122
123 self.connected_dev_servers
124 .insert(dev_server_id, connection_id);
125 }
126
127 #[instrument(skip(self))]
128 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
129 let connection = self
130 .connections
131 .get_mut(&connection_id)
132 .ok_or_else(|| anyhow!("no such connection"))?;
133
134 match connection.principal_id {
135 PrincipalId::UserId(user_id) => {
136 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
137 connected_user.connection_ids.remove(&connection_id);
138 if connected_user.connection_ids.is_empty() {
139 self.connected_users.remove(&user_id);
140 self.channels.remove_user(&user_id);
141 }
142 }
143 PrincipalId::DevServerId(dev_server_id) => {
144 self.connected_dev_servers.remove(&dev_server_id);
145 self.offline_dev_servers.remove(&dev_server_id);
146 }
147 }
148 self.connections.remove(&connection_id).unwrap();
149 Ok(())
150 }
151
152 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
153 self.offline_dev_servers.insert(dev_server_id);
154 }
155
156 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
157 self.connections.values()
158 }
159
160 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
161 self.connected_users
162 .get(&user_id)
163 .into_iter()
164 .flat_map(|state| {
165 state
166 .connection_ids
167 .iter()
168 .flat_map(|cid| self.connections.get(cid))
169 })
170 }
171
172 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
173 self.connected_users
174 .get(&user_id)
175 .into_iter()
176 .flat_map(|state| &state.connection_ids)
177 .copied()
178 }
179
180 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
181 if self.dev_server_connection_id(dev_server_id).is_some()
182 && !self.offline_dev_servers.contains(&dev_server_id)
183 {
184 proto::DevServerStatus::Online
185 } else {
186 proto::DevServerStatus::Offline
187 }
188 }
189
190 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
191 self.connected_dev_servers.get(&dev_server_id).copied()
192 }
193
194 pub fn dev_server_connection_id_supporting(
195 &self,
196 dev_server_id: DevServerId,
197 required: ZedVersion,
198 ) -> Result<ConnectionId> {
199 match self.connected_dev_servers.get(&dev_server_id) {
200 Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
201 Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
202 None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
203 }
204 }
205
206 pub fn channel_user_ids(
207 &self,
208 channel_id: ChannelId,
209 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
210 self.channels.users_to_notify(channel_id)
211 }
212
213 pub fn channel_connection_ids(
214 &self,
215 channel_id: ChannelId,
216 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
217 self.channels
218 .users_to_notify(channel_id)
219 .flat_map(|(user_id, role)| {
220 self.user_connection_ids(user_id)
221 .map(move |connection_id| (connection_id, role))
222 })
223 }
224
225 pub fn subscribe_to_channel(
226 &mut self,
227 user_id: UserId,
228 channel_id: ChannelId,
229 role: ChannelRole,
230 ) {
231 self.channels.subscribe(user_id, channel_id, role);
232 }
233
234 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
235 self.channels.unsubscribe(user_id, channel_id);
236 }
237
238 pub fn is_user_online(&self, user_id: UserId) -> bool {
239 !self
240 .connected_users
241 .get(&user_id)
242 .unwrap_or(&Default::default())
243 .connection_ids
244 .is_empty()
245 }
246
247 #[cfg(test)]
248 pub fn check_invariants(&self) {
249 for (connection_id, connection) in &self.connections {
250 match &connection.principal_id {
251 PrincipalId::UserId(user_id) => {
252 assert!(self
253 .connected_users
254 .get(user_id)
255 .unwrap()
256 .connection_ids
257 .contains(connection_id));
258 }
259 PrincipalId::DevServerId(dev_server_id) => {
260 assert_eq!(
261 self.connected_dev_servers.get(&dev_server_id).unwrap(),
262 connection_id
263 );
264 }
265 }
266 }
267
268 for (user_id, state) in &self.connected_users {
269 for connection_id in &state.connection_ids {
270 assert_eq!(
271 self.connections.get(connection_id).unwrap().principal_id,
272 PrincipalId::UserId(*user_id)
273 );
274 }
275 }
276
277 for (dev_server_id, connection_id) in &self.connected_dev_servers {
278 assert_eq!(
279 self.connections.get(connection_id).unwrap().principal_id,
280 PrincipalId::DevServerId(*dev_server_id)
281 );
282 }
283 }
284}
285
286#[derive(Default, Serialize)]
287pub struct ChannelPool {
288 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
289 by_channel: HashMap<ChannelId, HashSet<UserId>>,
290}
291
292impl ChannelPool {
293 pub fn clear(&mut self) {
294 self.by_user.clear();
295 self.by_channel.clear();
296 }
297
298 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
299 self.by_user
300 .entry(user_id)
301 .or_default()
302 .insert(channel_id, role);
303 self.by_channel
304 .entry(channel_id)
305 .or_default()
306 .insert(user_id);
307 }
308
309 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
310 if let Some(channels) = self.by_user.get_mut(user_id) {
311 channels.remove(channel_id);
312 if channels.is_empty() {
313 self.by_user.remove(user_id);
314 }
315 }
316 if let Some(users) = self.by_channel.get_mut(channel_id) {
317 users.remove(user_id);
318 if users.is_empty() {
319 self.by_channel.remove(channel_id);
320 }
321 }
322 }
323
324 pub fn remove_user(&mut self, user_id: &UserId) {
325 if let Some(channels) = self.by_user.remove(&user_id) {
326 for channel_id in channels.keys() {
327 self.unsubscribe(user_id, &channel_id)
328 }
329 }
330 }
331
332 pub fn users_to_notify(
333 &self,
334 channel_id: ChannelId,
335 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
336 self.by_channel
337 .get(&channel_id)
338 .into_iter()
339 .flat_map(move |users| {
340 users.iter().flat_map(move |user_id| {
341 Some((
342 *user_id,
343 self.by_user
344 .get(user_id)
345 .and_then(|channels| channels.get(&channel_id))
346 .copied()?,
347 ))
348 })
349 })
350 }
351}