1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
/// In-memory registry of open RPC connections.
///
/// Connections are indexed by connection id, by owning user, and by dev server. The registry
/// also tracks per-channel subscriptions of users and which dev servers have been marked offline.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every registered connection together with its principal, admin flag and Zed version.
    connections: BTreeMap<ConnectionId, Connection>,
    /// Connection ids grouped by owning user; a user is removed once their last connection goes.
    connected_users: BTreeMap<UserId, ConnectedPrincipal>,
    /// The connection currently registered for each dev server.
    connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
    /// Channel subscriptions (and roles) of users.
    channels: ChannelPool,
    /// Dev servers marked offline via `set_dev_server_offline`; such a dev server reports
    /// `Offline` even while its connection is still registered.
    offline_dev_servers: HashSet<DevServerId>,
}
18
/// Per-user bookkeeping: the ids of every connection the user currently has registered.
#[derive(Default, Serialize)]
struct ConnectedPrincipal {
    connection_ids: HashSet<ConnectionId>,
}
23
/// Version of Zed associated with a connection (user client or dev server).
///
/// Ordering is that of the wrapped semantic version, which lets callers compare a connection's
/// version against a required minimum.
#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
pub struct ZedVersion(pub SemanticVersion);
26
27impl fmt::Display for ZedVersion {
28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29 write!(f, "{}", self.0)
30 }
31}
32
33impl ZedVersion {
34 pub fn can_collaborate(&self) -> bool {
35 self.0 >= SemanticVersion::new(0, 129, 2)
36 }
37
38 pub fn with_save_as() -> ZedVersion {
39 ZedVersion(SemanticVersion::new(0, 134, 0))
40 }
41
42 pub fn with_list_directory() -> ZedVersion {
43 ZedVersion(SemanticVersion::new(0, 145, 0))
44 }
45
46 pub fn with_search_candidates() -> ZedVersion {
47 ZedVersion(SemanticVersion::new(0, 151, 0))
48 }
49}
50
51pub trait VersionedMessage {
52 fn required_host_version(&self) -> Option<ZedVersion> {
53 None
54 }
55}
56
57impl VersionedMessage for proto::SaveBuffer {
58 fn required_host_version(&self) -> Option<ZedVersion> {
59 if self.new_path.is_some() {
60 Some(ZedVersion::with_save_as())
61 } else {
62 None
63 }
64 }
65}
66
67impl VersionedMessage for proto::OpenNewBuffer {
68 fn required_host_version(&self) -> Option<ZedVersion> {
69 Some(ZedVersion::with_save_as())
70 }
71}
72
/// A single registered connection.
#[derive(Serialize)]
pub struct Connection {
    /// Who is on the other end of the connection: a user or a dev server.
    pub principal_id: PrincipalId,
    /// Whether the connection was registered with admin rights (always `false` for dev servers).
    pub admin: bool,
    /// Zed version of this connection, compared against required minimum versions when routing.
    pub zed_version: ZedVersion,
}
79
80impl ConnectionPool {
81 pub fn reset(&mut self) {
82 self.connections.clear();
83 self.connected_users.clear();
84 self.connected_dev_servers.clear();
85 self.channels.clear();
86 }
87
88 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
89 self.connections.get(&connection_id)
90 }
91
92 #[instrument(skip(self))]
93 pub fn add_connection(
94 &mut self,
95 connection_id: ConnectionId,
96 user_id: UserId,
97 admin: bool,
98 zed_version: ZedVersion,
99 ) {
100 self.connections.insert(
101 connection_id,
102 Connection {
103 principal_id: PrincipalId::UserId(user_id),
104 admin,
105 zed_version,
106 },
107 );
108 let connected_user = self.connected_users.entry(user_id).or_default();
109 connected_user.connection_ids.insert(connection_id);
110 }
111
112 pub fn add_dev_server(
113 &mut self,
114 connection_id: ConnectionId,
115 dev_server_id: DevServerId,
116 zed_version: ZedVersion,
117 ) {
118 self.connections.insert(
119 connection_id,
120 Connection {
121 principal_id: PrincipalId::DevServerId(dev_server_id),
122 admin: false,
123 zed_version,
124 },
125 );
126
127 self.connected_dev_servers
128 .insert(dev_server_id, connection_id);
129 }
130
131 #[instrument(skip(self))]
132 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
133 let connection = self
134 .connections
135 .get_mut(&connection_id)
136 .ok_or_else(|| anyhow!("no such connection"))?;
137
138 match connection.principal_id {
139 PrincipalId::UserId(user_id) => {
140 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
141 connected_user.connection_ids.remove(&connection_id);
142 if connected_user.connection_ids.is_empty() {
143 self.connected_users.remove(&user_id);
144 self.channels.remove_user(&user_id);
145 }
146 }
147 PrincipalId::DevServerId(dev_server_id) => {
148 self.connected_dev_servers.remove(&dev_server_id);
149 self.offline_dev_servers.remove(&dev_server_id);
150 }
151 }
152 self.connections.remove(&connection_id).unwrap();
153 Ok(())
154 }
155
156 pub fn set_dev_server_offline(&mut self, dev_server_id: DevServerId) {
157 self.offline_dev_servers.insert(dev_server_id);
158 }
159
160 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
161 self.connections.values()
162 }
163
164 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
165 self.connected_users
166 .get(&user_id)
167 .into_iter()
168 .flat_map(|state| {
169 state
170 .connection_ids
171 .iter()
172 .flat_map(|cid| self.connections.get(cid))
173 })
174 }
175
176 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
177 self.connected_users
178 .get(&user_id)
179 .into_iter()
180 .flat_map(|state| &state.connection_ids)
181 .copied()
182 }
183
184 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
185 if self.dev_server_connection_id(dev_server_id).is_some()
186 && !self.offline_dev_servers.contains(&dev_server_id)
187 {
188 proto::DevServerStatus::Online
189 } else {
190 proto::DevServerStatus::Offline
191 }
192 }
193
194 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
195 self.connected_dev_servers.get(&dev_server_id).copied()
196 }
197
198 pub fn dev_server_connection_id_supporting(
199 &self,
200 dev_server_id: DevServerId,
201 required: ZedVersion,
202 ) -> Result<ConnectionId> {
203 match self.connected_dev_servers.get(&dev_server_id) {
204 Some(cid) if self.connections[cid].zed_version >= required => Ok(*cid),
205 Some(_) => Err(anyhow!(proto::ErrorCode::RemoteUpgradeRequired)),
206 None => Err(anyhow!(proto::ErrorCode::DevServerOffline)),
207 }
208 }
209
210 pub fn channel_user_ids(
211 &self,
212 channel_id: ChannelId,
213 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
214 self.channels.users_to_notify(channel_id)
215 }
216
217 pub fn channel_connection_ids(
218 &self,
219 channel_id: ChannelId,
220 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
221 self.channels
222 .users_to_notify(channel_id)
223 .flat_map(|(user_id, role)| {
224 self.user_connection_ids(user_id)
225 .map(move |connection_id| (connection_id, role))
226 })
227 }
228
229 pub fn subscribe_to_channel(
230 &mut self,
231 user_id: UserId,
232 channel_id: ChannelId,
233 role: ChannelRole,
234 ) {
235 self.channels.subscribe(user_id, channel_id, role);
236 }
237
238 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
239 self.channels.unsubscribe(user_id, channel_id);
240 }
241
242 pub fn is_user_online(&self, user_id: UserId) -> bool {
243 !self
244 .connected_users
245 .get(&user_id)
246 .unwrap_or(&Default::default())
247 .connection_ids
248 .is_empty()
249 }
250
251 #[cfg(test)]
252 pub fn check_invariants(&self) {
253 for (connection_id, connection) in &self.connections {
254 match &connection.principal_id {
255 PrincipalId::UserId(user_id) => {
256 assert!(self
257 .connected_users
258 .get(user_id)
259 .unwrap()
260 .connection_ids
261 .contains(connection_id));
262 }
263 PrincipalId::DevServerId(dev_server_id) => {
264 assert_eq!(
265 self.connected_dev_servers.get(&dev_server_id).unwrap(),
266 connection_id
267 );
268 }
269 }
270 }
271
272 for (user_id, state) in &self.connected_users {
273 for connection_id in &state.connection_ids {
274 assert_eq!(
275 self.connections.get(connection_id).unwrap().principal_id,
276 PrincipalId::UserId(*user_id)
277 );
278 }
279 }
280
281 for (dev_server_id, connection_id) in &self.connected_dev_servers {
282 assert_eq!(
283 self.connections.get(connection_id).unwrap().principal_id,
284 PrincipalId::DevServerId(*dev_server_id)
285 );
286 }
287 }
288}
289
/// Two-way index of channel subscriptions: each user's channels (with the user's role) and
/// each channel's subscribed users. `subscribe` and `unsubscribe` keep both maps in sync.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// For each user, the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// For each channel, the users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
295
296impl ChannelPool {
297 pub fn clear(&mut self) {
298 self.by_user.clear();
299 self.by_channel.clear();
300 }
301
302 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
303 self.by_user
304 .entry(user_id)
305 .or_default()
306 .insert(channel_id, role);
307 self.by_channel
308 .entry(channel_id)
309 .or_default()
310 .insert(user_id);
311 }
312
313 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
314 if let Some(channels) = self.by_user.get_mut(user_id) {
315 channels.remove(channel_id);
316 if channels.is_empty() {
317 self.by_user.remove(user_id);
318 }
319 }
320 if let Some(users) = self.by_channel.get_mut(channel_id) {
321 users.remove(user_id);
322 if users.is_empty() {
323 self.by_channel.remove(channel_id);
324 }
325 }
326 }
327
328 pub fn remove_user(&mut self, user_id: &UserId) {
329 if let Some(channels) = self.by_user.remove(&user_id) {
330 for channel_id in channels.keys() {
331 self.unsubscribe(user_id, &channel_id)
332 }
333 }
334 }
335
336 pub fn users_to_notify(
337 &self,
338 channel_id: ChannelId,
339 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
340 self.by_channel
341 .get(&channel_id)
342 .into_iter()
343 .flat_map(move |users| {
344 users.iter().flat_map(move |user_id| {
345 Some((
346 *user_id,
347 self.by_user
348 .get(user_id)
349 .and_then(|channels| channels.get(&channel_id))
350 .copied()?,
351 ))
352 })
353 })
354 }
355}