1use crate::db::{ChannelId, ChannelRole, UserId};
2use anyhow::{Context as _, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::ConnectionId;
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 channels: ChannelPool,
15}
16
17#[derive(Default, Serialize)]
18struct ConnectedPrincipal {
19 connection_ids: HashSet<ConnectionId>,
20}
21
22#[derive(Copy, Clone, Debug, Serialize, PartialOrd, PartialEq, Eq, Ord)]
23pub struct ZedVersion(pub SemanticVersion);
24
25impl fmt::Display for ZedVersion {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl ZedVersion {
32 pub fn can_collaborate(&self) -> bool {
33 // v0.198.4 is the first version where we no longer connect to Collab automatically.
34 // We reject any clients older than that to prevent them from connecting to Collab just for authentication.
35 if self.0 < SemanticVersion::new(0, 198, 4) {
36 return false;
37 }
38
39 // Since we hotfixed the changes to no longer connect to Collab automatically to Preview, we also need to reject
40 // versions in the range [v0.199.0, v0.199.1].
41 if self.0 >= SemanticVersion::new(0, 199, 0) && self.0 < SemanticVersion::new(0, 199, 2) {
42 return false;
43 }
44
45 true
46 }
47}
48
49#[derive(Serialize)]
50pub struct Connection {
51 pub user_id: UserId,
52 pub admin: bool,
53 pub zed_version: ZedVersion,
54}
55
56impl ConnectionPool {
57 pub fn reset(&mut self) {
58 self.connections.clear();
59 self.connected_users.clear();
60 self.channels.clear();
61 }
62
63 pub fn connection(&mut self, connection_id: ConnectionId) -> Option<&Connection> {
64 self.connections.get(&connection_id)
65 }
66
67 #[instrument(skip(self))]
68 pub fn add_connection(
69 &mut self,
70 connection_id: ConnectionId,
71 user_id: UserId,
72 admin: bool,
73 zed_version: ZedVersion,
74 ) {
75 self.connections.insert(
76 connection_id,
77 Connection {
78 user_id,
79 admin,
80 zed_version,
81 },
82 );
83 let connected_user = self.connected_users.entry(user_id).or_default();
84 connected_user.connection_ids.insert(connection_id);
85 }
86
87 #[instrument(skip(self))]
88 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
89 let connection = self
90 .connections
91 .get_mut(&connection_id)
92 .context("no such connection")?;
93
94 let user_id = connection.user_id;
95
96 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
97 connected_user.connection_ids.remove(&connection_id);
98 if connected_user.connection_ids.is_empty() {
99 self.connected_users.remove(&user_id);
100 self.channels.remove_user(&user_id);
101 };
102 self.connections.remove(&connection_id).unwrap();
103 Ok(())
104 }
105
106 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
107 self.connections.values()
108 }
109
110 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
111 self.connected_users
112 .get(&user_id)
113 .into_iter()
114 .flat_map(|state| {
115 state
116 .connection_ids
117 .iter()
118 .flat_map(|cid| self.connections.get(cid))
119 })
120 }
121
122 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
123 self.connected_users
124 .get(&user_id)
125 .into_iter()
126 .flat_map(|state| &state.connection_ids)
127 .copied()
128 }
129
130 pub fn channel_user_ids(
131 &self,
132 channel_id: ChannelId,
133 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
134 self.channels.users_to_notify(channel_id)
135 }
136
137 pub fn channel_connection_ids(
138 &self,
139 channel_id: ChannelId,
140 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
141 self.channels
142 .users_to_notify(channel_id)
143 .flat_map(|(user_id, role)| {
144 self.user_connection_ids(user_id)
145 .map(move |connection_id| (connection_id, role))
146 })
147 }
148
149 pub fn subscribe_to_channel(
150 &mut self,
151 user_id: UserId,
152 channel_id: ChannelId,
153 role: ChannelRole,
154 ) {
155 self.channels.subscribe(user_id, channel_id, role);
156 }
157
158 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
159 self.channels.unsubscribe(user_id, channel_id);
160 }
161
162 pub fn is_user_online(&self, user_id: UserId) -> bool {
163 !self
164 .connected_users
165 .get(&user_id)
166 .unwrap_or(&Default::default())
167 .connection_ids
168 .is_empty()
169 }
170
171 #[cfg(test)]
172 pub fn check_invariants(&self) {
173 for (connection_id, connection) in &self.connections {
174 assert!(
175 self.connected_users
176 .get(&connection.user_id)
177 .unwrap()
178 .connection_ids
179 .contains(connection_id)
180 );
181 }
182
183 for (user_id, state) in &self.connected_users {
184 for connection_id in &state.connection_ids {
185 assert_eq!(
186 self.connections.get(connection_id).unwrap().user_id,
187 *user_id
188 );
189 }
190 }
191 }
192}
193
194#[derive(Default, Serialize)]
195pub struct ChannelPool {
196 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
197 by_channel: HashMap<ChannelId, HashSet<UserId>>,
198}
199
200impl ChannelPool {
201 pub fn clear(&mut self) {
202 self.by_user.clear();
203 self.by_channel.clear();
204 }
205
206 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
207 self.by_user
208 .entry(user_id)
209 .or_default()
210 .insert(channel_id, role);
211 self.by_channel
212 .entry(channel_id)
213 .or_default()
214 .insert(user_id);
215 }
216
217 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
218 if let Some(channels) = self.by_user.get_mut(user_id) {
219 channels.remove(channel_id);
220 if channels.is_empty() {
221 self.by_user.remove(user_id);
222 }
223 }
224 if let Some(users) = self.by_channel.get_mut(channel_id) {
225 users.remove(user_id);
226 if users.is_empty() {
227 self.by_channel.remove(channel_id);
228 }
229 }
230 }
231
232 pub fn remove_user(&mut self, user_id: &UserId) {
233 if let Some(channels) = self.by_user.remove(user_id) {
234 for channel_id in channels.keys() {
235 self.unsubscribe(user_id, channel_id)
236 }
237 }
238 }
239
240 pub fn users_to_notify(
241 &self,
242 channel_id: ChannelId,
243 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
244 self.by_channel
245 .get(&channel_id)
246 .into_iter()
247 .flat_map(move |users| {
248 users.iter().flat_map(move |user_id| {
249 Some((
250 *user_id,
251 self.by_user
252 .get(user_id)
253 .and_then(|channels| channels.get(&channel_id))
254 .copied()?,
255 ))
256 })
257 })
258 }
259}