1use crate::db::{ChannelId, ChannelRole, DevServerId, PrincipalId, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::{proto, ConnectionId};
5use semantic_version::SemanticVersion;
6use serde::Serialize;
7use std::fmt;
8use tracing::instrument;
9
10#[derive(Default, Serialize)]
11pub struct ConnectionPool {
12 connections: BTreeMap<ConnectionId, Connection>,
13 connected_users: BTreeMap<UserId, ConnectedPrincipal>,
14 connected_dev_servers: BTreeMap<DevServerId, ConnectionId>,
15 channels: ChannelPool,
16}
17
18#[derive(Default, Serialize)]
19struct ConnectedPrincipal {
20 connection_ids: HashSet<ConnectionId>,
21}
22
23#[derive(Debug, Serialize)]
24pub struct ZedVersion(pub SemanticVersion);
25
26impl fmt::Display for ZedVersion {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "{}", self.0)
29 }
30}
31
32impl ZedVersion {
33 pub fn can_collaborate(&self) -> bool {
34 self.0 >= SemanticVersion::new(0, 129, 2)
35 }
36}
37
38#[derive(Serialize)]
39pub struct Connection {
40 pub principal_id: PrincipalId,
41 pub admin: bool,
42 pub zed_version: ZedVersion,
43}
44
45impl ConnectionPool {
46 pub fn reset(&mut self) {
47 self.connections.clear();
48 self.connected_users.clear();
49 self.channels.clear();
50 }
51
52 #[instrument(skip(self))]
53 pub fn add_connection(
54 &mut self,
55 connection_id: ConnectionId,
56 user_id: UserId,
57 admin: bool,
58 zed_version: ZedVersion,
59 ) {
60 self.connections.insert(
61 connection_id,
62 Connection {
63 principal_id: PrincipalId::UserId(user_id),
64 admin,
65 zed_version,
66 },
67 );
68 let connected_user = self.connected_users.entry(user_id).or_default();
69 connected_user.connection_ids.insert(connection_id);
70 }
71
72 pub fn add_dev_server(
73 &mut self,
74 connection_id: ConnectionId,
75 dev_server_id: DevServerId,
76 zed_version: ZedVersion,
77 ) {
78 self.connections.insert(
79 connection_id,
80 Connection {
81 principal_id: PrincipalId::DevServerId(dev_server_id),
82 admin: false,
83 zed_version,
84 },
85 );
86
87 self.connected_dev_servers
88 .insert(dev_server_id, connection_id);
89 }
90
91 #[instrument(skip(self))]
92 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
93 let connection = self
94 .connections
95 .get_mut(&connection_id)
96 .ok_or_else(|| anyhow!("no such connection"))?;
97
98 match connection.principal_id {
99 PrincipalId::UserId(user_id) => {
100 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
101 connected_user.connection_ids.remove(&connection_id);
102 if connected_user.connection_ids.is_empty() {
103 self.connected_users.remove(&user_id);
104 self.channels.remove_user(&user_id);
105 }
106 }
107 PrincipalId::DevServerId(dev_server_id) => {
108 self.connected_dev_servers.remove(&dev_server_id);
109 }
110 }
111 self.connections.remove(&connection_id).unwrap();
112 Ok(())
113 }
114
115 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
116 self.connections.values()
117 }
118
119 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
120 self.connected_users
121 .get(&user_id)
122 .into_iter()
123 .flat_map(|state| {
124 state
125 .connection_ids
126 .iter()
127 .flat_map(|cid| self.connections.get(cid))
128 })
129 }
130
131 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
132 self.connected_users
133 .get(&user_id)
134 .into_iter()
135 .flat_map(|state| &state.connection_ids)
136 .copied()
137 }
138
139 pub fn dev_server_status(&self, dev_server_id: DevServerId) -> proto::DevServerStatus {
140 if self.dev_server_connection_id(dev_server_id).is_some() {
141 proto::DevServerStatus::Online
142 } else {
143 proto::DevServerStatus::Offline
144 }
145 }
146
147 pub fn dev_server_connection_id(&self, dev_server_id: DevServerId) -> Option<ConnectionId> {
148 self.connected_dev_servers.get(&dev_server_id).copied()
149 }
150
151 pub fn channel_user_ids(
152 &self,
153 channel_id: ChannelId,
154 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
155 self.channels.users_to_notify(channel_id)
156 }
157
158 pub fn channel_connection_ids(
159 &self,
160 channel_id: ChannelId,
161 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
162 self.channels
163 .users_to_notify(channel_id)
164 .flat_map(|(user_id, role)| {
165 self.user_connection_ids(user_id)
166 .map(move |connection_id| (connection_id, role))
167 })
168 }
169
170 pub fn subscribe_to_channel(
171 &mut self,
172 user_id: UserId,
173 channel_id: ChannelId,
174 role: ChannelRole,
175 ) {
176 self.channels.subscribe(user_id, channel_id, role);
177 }
178
179 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
180 self.channels.unsubscribe(user_id, channel_id);
181 }
182
183 pub fn is_user_online(&self, user_id: UserId) -> bool {
184 !self
185 .connected_users
186 .get(&user_id)
187 .unwrap_or(&Default::default())
188 .connection_ids
189 .is_empty()
190 }
191
192 #[cfg(test)]
193 pub fn check_invariants(&self) {
194 for (connection_id, connection) in &self.connections {
195 match &connection.principal_id {
196 PrincipalId::UserId(user_id) => {
197 assert!(self
198 .connected_users
199 .get(user_id)
200 .unwrap()
201 .connection_ids
202 .contains(connection_id));
203 }
204 PrincipalId::DevServerId(dev_server_id) => {
205 assert_eq!(
206 self.connected_dev_servers.get(&dev_server_id).unwrap(),
207 connection_id
208 );
209 }
210 }
211 }
212
213 for (user_id, state) in &self.connected_users {
214 for connection_id in &state.connection_ids {
215 assert_eq!(
216 self.connections.get(connection_id).unwrap().principal_id,
217 PrincipalId::UserId(*user_id)
218 );
219 }
220 }
221
222 for (dev_server_id, connection_id) in &self.connected_dev_servers {
223 assert_eq!(
224 self.connections.get(connection_id).unwrap().principal_id,
225 PrincipalId::DevServerId(*dev_server_id)
226 );
227 }
228 }
229}
230
231#[derive(Default, Serialize)]
232pub struct ChannelPool {
233 by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
234 by_channel: HashMap<ChannelId, HashSet<UserId>>,
235}
236
237impl ChannelPool {
238 pub fn clear(&mut self) {
239 self.by_user.clear();
240 self.by_channel.clear();
241 }
242
243 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
244 self.by_user
245 .entry(user_id)
246 .or_default()
247 .insert(channel_id, role);
248 self.by_channel
249 .entry(channel_id)
250 .or_default()
251 .insert(user_id);
252 }
253
254 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
255 if let Some(channels) = self.by_user.get_mut(user_id) {
256 channels.remove(channel_id);
257 if channels.is_empty() {
258 self.by_user.remove(user_id);
259 }
260 }
261 if let Some(users) = self.by_channel.get_mut(channel_id) {
262 users.remove(user_id);
263 if users.is_empty() {
264 self.by_channel.remove(channel_id);
265 }
266 }
267 }
268
269 pub fn remove_user(&mut self, user_id: &UserId) {
270 if let Some(channels) = self.by_user.remove(&user_id) {
271 for channel_id in channels.keys() {
272 self.unsubscribe(user_id, &channel_id)
273 }
274 }
275 }
276
277 pub fn users_to_notify(
278 &self,
279 channel_id: ChannelId,
280 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
281 self.by_channel
282 .get(&channel_id)
283 .into_iter()
284 .flat_map(move |users| {
285 users.iter().flat_map(move |user_id| {
286 Some((
287 *user_id,
288 self.by_user
289 .get(user_id)
290 .and_then(|channels| channels.get(&channel_id))
291 .copied()?,
292 ))
293 })
294 })
295 }
296}