1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
15 channels: ChannelPool,
16 offline_dev_servers: HashSet<DevServerId>,
17}
18
19#[derive(Default, Serialize)]
20struct ConnectedPrincipal {
21 connection_ids: HashSet<ConnectionId>,
22}
23
24#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
25pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 129, 2)
36 }
37
38 pub fn with_save_as() -> ZedVersion {
39 ZedVersion(SemanticVersion::new(0, 134, 0))
40 }
41}
42
43pub trait VersionedMessage {
44 fn required_host_version(&self) -> Option<ZedVersion> {
45 None
46 }
47}
48
49impl VersionedMessage for proto::SaveBuffer {
50 fn required_host_version(&self) -> Option<ZedVersion> {
51 if self.new_path.is_some() {
52 Some(ZedVersion::with_save_as())
53 } else {
54 None
55 }
56 }
57}
58
59impl VersionedMessage for proto::OpenNewBuffer {
60 fn required_host_version(&self) -> Option<ZedVersion> {
61 Some(ZedVersion::with_save_as())
62 }
63}
64
65#[derive(Serialize)]
66pub struct Connection {
67 pub principal_id: PrincipalId,
68 pub admin: bool,
69 pub zed_version: ZedVersion,
70}
71
72impl ConnectionPool {
73 pub fn reset(&mut self) {
74 self.connections.clear();
75 self.connected_users.clear();
76 self.connected_dev_servers.clear();
77 self.channels.clear();
78 }
79
80 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
81 self.connections.get(&connection_id)
82 }
83
84 #[instrument(skip(self))]
85 pub fn add_connection(
86 &mut self,
87 connection_id: ConnectionId,
88 user_id: UserId,
89 admin: bool,
90 zed_version: ZedVersion,
91 ) {
92 self.connections.insert(
93 connection_id,
94 Connection {
95 principal_id: PrincipalId::UserId(user_id),
96 admin,
97 zed_version,
98 },
99 );
100 let connected_user = self.connected_users.entry(user_id).or_default();
101 connected_user.connection_ids.insert(connection_id);
102 }
103
104 pub fn add_dev_server(
105 &mut self,
106 connection_id: ConnectionId,
107 dev_server_id: DevServerId,
108 zed_version: ZedVersion,
109 ) {
110 self.connections.insert(
111 connection_id,
112 Connection {
113 principal_id: PrincipalId::DevServerId(dev_server_id),
114 admin: false,
115 zed_version,
116 },
117 );
118
119 self.connected_dev_servers
120 .insert(dev_server_id, connection_id);
121 }
122
123 #[instrument(skip(self))]
124 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
125 let connection = self
126 .connections
127 .get_mut(&connection_id)
128 .ok_or_else(|| anyhow!("no such connection"))?;
129
130 match connection.principal_id {
131 PrincipalId::UserId(user_id) => {
132 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
133 connected_user.connection_ids.remove(&connection_id);
134 if connected_user.connection_ids.is_empty() {
135 self.connected_users.remove(&user_id);
136 self.channels.remove_user(&user_id);
137 }
138 }
139 PrincipalId::DevServerId(dev_server_id) => {
140 self.connected_dev_servers.remove(&dev_server_id);
141 self.offline_dev_servers.remove(&dev_server_id);
142 }
143 }
144 self.connections.remove(&connection_id).unwrap();
145 Ok(())
146 }
147
148 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
149 self.offline_dev_servers.insert(dev_server_id);
150 }
151
152 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
153 self.connections.values()
154 }
155
156 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
157 self.connected_users
158 .get(&user_id)
159 .into_iter()
160 .flat_map(|state| {
161 state
162 .connection_ids
163 .iter()
164 .flat_map(|cid| self.connections.get(cid))
165 })
166 }
167
168 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
169 self.connected_users
170 .get(&user_id)
171 .into_iter()
172 .flat_map(|state| &state.connection_ids)
173 .copied()
174 }
175
176 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
177 if self.dev_server_connection_id(dev_server_id).is_some()
178 && !self.offline_dev_servers.contains(&dev_server_id)
179 {
180 proto::DevServerStatus::Online
181 } else {
182 proto::DevServerStatus::Offline
183 }
184 }
185
186 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
187 self.connected_dev_servers.get(&dev_server_id).copied()
188 }
189
190 pub fn channel_user_ids(
191 &self,
192 channel_id: ChannelId,
193 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
194 self.channels.users_to_notify(channel_id)
195 }
196
197 pub fn channel_connection_ids(
198 &self,
199 channel_id: ChannelId,
200 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
201 self.channels
202 .users_to_notify(channel_id)
203 .flat_map(|(user_id, role)| {
204 self.user_connection_ids(user_id)
205 .map(move |connection_id| (connection_id, role))
206 })
207 }
208
209 pub fn subscribe_to_channel(
210 &mut self,
211 user_id: UserId,
212 channel_id: ChannelId,
213 role: ChannelRole,
214 ) {
215 self.channels.subscribe(user_id, channel_id, role);
216 }
217
218 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
219 self.channels.unsubscribe(user_id, channel_id);
220 }
221
222 pub fn is_user_online(&self, user_id: UserId) -> bool {
223 !self
224 .connected_users
225 .get(&user_id)
226 .unwrap_or(&Default::default())
227 .connection_ids
228 .is_empty()
229 }
230
231 #[cfg(test)]
232 pub fn check_invariants(&self) {
233 for (connection_id, connection) in &self.connections {
234 match &connection.principal_id {
235 PrincipalId::UserId(user_id) => {
236 assert!(self
237 .connected_users
238 .get(user_id)
239 .unwrap()
240 .connection_ids
241 .contains(connection_id));
242 }
243 PrincipalId::DevServerId(dev_server_id) => {
244 assert_eq!(
245 self.connected_dev_servers.get(&dev_server_id).unwrap(),
246 connection_id
247 );
248 }
249 }
250 }
251
252 for (user_id, state) in &self.connected_users {
253 for connection_id in &state.connection_ids {
254 assert_eq!(
255 self.connections.get(connection_id).unwrap().principal_id,
256 PrincipalId::UserId(*user_id)
257 );
258 }
259 }
260
261 for (dev_server_id, connection_id) in &self.connected_dev_servers {
262 assert_eq!(
263 self.connections.get(connection_id).unwrap().principal_id,
264 PrincipalId::DevServerId(*dev_server_id)
265 );
266 }
267 }
268}
269
270#[derive(Default, Serialize)]
271pub struct ChannelPool {
272 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
273 by_channel: HashMap<ChannelId, HashSet<UserId>>,
274}
275
276impl ChannelPool {
277 pub fn clear(&mut self) {
278 self.by_user.clear();
279 self.by_channel.clear();
280 }
281
282 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
283 self.by_user
284 .entry(user_id)
285 .or_default()
286 .insert(channel_id, role);
287 self.by_channel
288 .entry(channel_id)
289 .or_default()
290 .insert(user_id);
291 }
292
293 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
294 if let Some(channels) = self.by_user.get_mut(user_id) {
295 channels.remove(channel_id);
296 if channels.is_empty() {
297 self.by_user.remove(user_id);
298 }
299 }
300 if let Some(users) = self.by_channel.get_mut(channel_id) {
301 users.remove(user_id);
302 if users.is_empty() {
303 self.by_channel.remove(channel_id);
304 }
305 }
306 }
307
308 pub fn remove_user(&mut self, user_id: &UserId) {
309 if let Some(channels) = self.by_user.remove(&user_id) {
310 for channel_id in channels.keys() {
311 self.unsubscribe(user_id, &channel_id)
312 }
313 }
314 }
315
316 pub fn users_to_notify(
317 &self,
318 channel_id: ChannelId,
319 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
320 self.by_channel
321 .get(&channel_id)
322 .into_iter()
323 .flat_map(move |users| {
324 users.iter().flat_map(move |user_id| {
325 Some((
326 *user_id,
327 self.by_user
328 .get(user_id)
329 .and_then(|channels| channels.get(&channel_id))
330 .copied()?,
331 ))
332 })
333 })
334 }
335}