1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
/// Registry of every live collab connection, indexed by connection, by user,
/// and by dev server, plus the channel subscriptions of connected users.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    // Every registered connection, keyed by its id.
    connections: BTreeMap<ConnectionId, Connection>,
    // For each connected user, the set of that user's connection ids.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    // Each connected dev server maps to exactly one connection id.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    // Channel subscriptions (user <-> channel, with roles).
    channels: ChannelPool,
    // Dev servers flagged via `set_dev_server_offline`; `dev_server_status`
    // reports these as offline even while they still have a connection.
    offline_dev_servers: HashSet<DevServerId>,
}
18
/// The connections currently held by one user. `ConnectionPool::remove_connection`
/// drops the entry once this set becomes empty.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    connection_ids: HashSet<ConnectionId>,
}
23
/// The Zed version reported by a connection.
///
/// A newtype over `SemanticVersion`. The derived ordering compares the wrapped
/// version.
#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 129, 2)
36 }
37
38 pub fn with_save_as() -> ZedVersion {
39 ZedVersion(SemanticVersion::new(0, 134, 0))
40 }
41}
42
/// A message that may demand a minimum Zed version on the host side.
///
/// NOTE(review): presumably callers compare this against the host
/// connection's `zed_version` before forwarding; confirm at the call sites.
pub trait VersionedMessage {
    /// Returns the minimum host version this message needs, or `None` (the
    /// default) when there is no requirement.
    fn required_host_version(&self) -> Option<ZedVersion> {
        None
    }
}
48
49impl VersionedMessage for proto::SaveBuffer {
50 fn required_host_version(&self) -> Option<ZedVersion> {
51 if self.new_path.is_some() {
52 Some(ZedVersion::with_save_as())
53 } else {
54 None
55 }
56 }
57}
58
59impl VersionedMessage for proto::OpenNewBuffer {
60 fn required_host_version(&self) -> Option<ZedVersion> {
61 Some(ZedVersion::with_save_as())
62 }
63}
64
/// Metadata about one registered connection.
#[derive(Serialize)]
pub struct Connection {
    // Who opened the connection: a user or a dev server.
    pub principal_id: PrincipalId,
    // Admin flag given to `add_connection`; always `false` for dev servers.
    pub admin: bool,
    // Version reported by the client on this connection.
    pub zed_version: ZedVersion,
}
71
72impl ConnectionPool {
73 pub fn reset(&mut self) {
74 self.connections.clear();
75 self.connected_users.clear();
76 self.channels.clear();
77 }
78
79 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
80 self.connections.get(&connection_id)
81 }
82
83 #[instrument(skip(self))]
84 pub fn add_connection(
85 &mut self,
86 connection_id: ConnectionId,
87 user_id: UserId,
88 admin: bool,
89 zed_version: ZedVersion,
90 ) {
91 self.connections.insert(
92 connection_id,
93 Connection {
94 principal_id: PrincipalId::UserId(user_id),
95 admin,
96 zed_version,
97 },
98 );
99 let connected_user = self.connected_users.entry(user_id).or_default();
100 connected_user.connection_ids.insert(connection_id);
101 }
102
103 pub fn add_dev_server(
104 &mut self,
105 connection_id: ConnectionId,
106 dev_server_id: DevServerId,
107 zed_version: ZedVersion,
108 ) {
109 self.connections.insert(
110 connection_id,
111 Connection {
112 principal_id: PrincipalId::DevServerId(dev_server_id),
113 admin: false,
114 zed_version,
115 },
116 );
117
118 self.connected_dev_servers
119 .insert(dev_server_id, connection_id);
120 }
121
122 #[instrument(skip(self))]
123 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
124 let connection = self
125 .connections
126 .get_mut(&connection_id)
127 .ok_or_else(|| anyhow!("no such connection"))?;
128
129 match connection.principal_id {
130 PrincipalId::UserId(user_id) => {
131 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
132 connected_user.connection_ids.remove(&connection_id);
133 if connected_user.connection_ids.is_empty() {
134 self.connected_users.remove(&user_id);
135 self.channels.remove_user(&user_id);
136 }
137 }
138 PrincipalId::DevServerId(dev_server_id) => {
139 self.connected_dev_servers.remove(&dev_server_id);
140 self.offline_dev_servers.remove(&dev_server_id);
141 }
142 }
143 self.connections.remove(&connection_id).unwrap();
144 Ok(())
145 }
146
147 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
148 self.offline_dev_servers.insert(dev_server_id);
149 }
150
151 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
152 self.connections.values()
153 }
154
155 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
156 self.connected_users
157 .get(&user_id)
158 .into_iter()
159 .flat_map(|state| {
160 state
161 .connection_ids
162 .iter()
163 .flat_map(|cid| self.connections.get(cid))
164 })
165 }
166
167 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
168 self.connected_users
169 .get(&user_id)
170 .into_iter()
171 .flat_map(|state| &state.connection_ids)
172 .copied()
173 }
174
175 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
176 if self.dev_server_connection_id(dev_server_id).is_some()
177 && !self.offline_dev_servers.contains(&dev_server_id)
178 {
179 proto::DevServerStatus::Online
180 } else {
181 proto::DevServerStatus::Offline
182 }
183 }
184
185 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
186 self.connected_dev_servers.get(&dev_server_id).copied()
187 }
188
189 pub fn channel_user_ids(
190 &self,
191 channel_id: ChannelId,
192 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
193 self.channels.users_to_notify(channel_id)
194 }
195
196 pub fn channel_connection_ids(
197 &self,
198 channel_id: ChannelId,
199 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
200 self.channels
201 .users_to_notify(channel_id)
202 .flat_map(|(user_id, role)| {
203 self.user_connection_ids(user_id)
204 .map(move |connection_id| (connection_id, role))
205 })
206 }
207
208 pub fn subscribe_to_channel(
209 &mut self,
210 user_id: UserId,
211 channel_id: ChannelId,
212 role: ChannelRole,
213 ) {
214 self.channels.subscribe(user_id, channel_id, role);
215 }
216
217 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
218 self.channels.unsubscribe(user_id, channel_id);
219 }
220
221 pub fn is_user_online(&self, user_id: UserId) -> bool {
222 !self
223 .connected_users
224 .get(&user_id)
225 .unwrap_or(&Default::default())
226 .connection_ids
227 .is_empty()
228 }
229
230 #[cfg(test)]
231 pub fn check_invariants(&self) {
232 for (connection_id, connection) in &self.connections {
233 match &connection.principal_id {
234 PrincipalId::UserId(user_id) => {
235 assert!(self
236 .connected_users
237 .get(user_id)
238 .unwrap()
239 .connection_ids
240 .contains(connection_id));
241 }
242 PrincipalId::DevServerId(dev_server_id) => {
243 assert_eq!(
244 self.connected_dev_servers.get(&dev_server_id).unwrap(),
245 connection_id
246 );
247 }
248 }
249 }
250
251 for (user_id, state) in &self.connected_users {
252 for connection_id in &state.connection_ids {
253 assert_eq!(
254 self.connections.get(connection_id).unwrap().principal_id,
255 PrincipalId::UserId(*user_id)
256 );
257 }
258 }
259
260 for (dev_server_id, connection_id) in &self.connected_dev_servers {
261 assert_eq!(
262 self.connections.get(connection_id).unwrap().principal_id,
263 PrincipalId::DevServerId(*dev_server_id)
264 );
265 }
266 }
267}
268
/// Channel subscriptions of connected users, indexed both ways.
/// `subscribe`/`unsubscribe` keep the two maps in sync.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    // user -> (channel -> that user's role in the channel)
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    // channel -> users subscribed to it (reverse index of `by_user`)
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
274
275impl ChannelPool {
276 pub fn clear(&mut self) {
277 self.by_user.clear();
278 self.by_channel.clear();
279 }
280
281 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
282 self.by_user
283 .entry(user_id)
284 .or_default()
285 .insert(channel_id, role);
286 self.by_channel
287 .entry(channel_id)
288 .or_default()
289 .insert(user_id);
290 }
291
292 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
293 if let Some(channels) = self.by_user.get_mut(user_id) {
294 channels.remove(channel_id);
295 if channels.is_empty() {
296 self.by_user.remove(user_id);
297 }
298 }
299 if let Some(users) = self.by_channel.get_mut(channel_id) {
300 users.remove(user_id);
301 if users.is_empty() {
302 self.by_channel.remove(channel_id);
303 }
304 }
305 }
306
307 pub fn remove_user(&mut self, user_id: &UserId) {
308 if let Some(channels) = self.by_user.remove(&user_id) {
309 for channel_id in channels.keys() {
310 self.unsubscribe(user_id, &channel_id)
311 }
312 }
313 }
314
315 pub fn users_to_notify(
316 &self,
317 channel_id: ChannelId,
318 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
319 self.by_channel
320 .get(&channel_id)
321 .into_iter()
322 .flat_map(move |users| {
323 users.iter().flat_map(move |user_id| {
324 Some((
325 *user_id,
326 self.by_user
327 .get(user_id)
328 .and_then(|channels| channels.get(&channel_id))
329 .copied()?,
330 ))
331 })
332 })
333 }
334}