1use crate::db::{ChannelId, ChannelRole, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::ConnectionId;
5use serde::Serialize;
6use tracing::instrument;
7use util::{semver, SemanticVersion};
8
/// In-memory registry of live connections, the users they belong to, and the channel
/// subscriptions of those users.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every registered connection, keyed by its connection id.
    connections: BTreeMap<ConnectionId, Connection>,
    /// Users that currently have at least one registered connection, keyed by user id.
    /// An entry is removed as soon as its last connection is removed.
    connected_users: BTreeMap<UserId, ConnectedUser>,
    /// Channel subscriptions, tracked per user and per channel.
    channels: ChannelPool,
}
15
/// Bookkeeping for a single user who has connections registered in a `ConnectionPool`.
#[derive(Default, Serialize)]
struct ConnectedUser {
    /// Ids of this user's registered connections.
    connection_ids: HashSet<ConnectionId>,
}
20
/// Newtype around a `SemanticVersion` identifying a Zed release (presumably the version of the
/// client on a connection — confirm at call sites). `can_collaborate` decides whether that
/// version may collaborate.
#[derive(Debug, Serialize)]
pub struct ZedVersion(pub SemanticVersion);
23use std::fmt;
24
25impl fmt::Display for ZedVersion {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl ZedVersion {
32 pub fn can_collaborate(&self) -> bool {
33 self.0 >= semver(0, 127, 3) || (self.0 >= semver(0, 126, 3) && self.0 < semver(0, 127, 0))
34 }
35}
36
/// Metadata stored for one registered connection.
#[derive(Serialize)]
pub struct Connection {
    /// The user this connection belongs to.
    pub user_id: UserId,
    /// The `admin` flag passed to `ConnectionPool::add_connection`; its meaning is defined by
    /// the caller.
    pub admin: bool,
    /// Zed version associated with this connection (presumably the connecting client's version).
    pub zed_version: ZedVersion,
}
43
44impl ConnectionPool {
45 pub fn reset(&mut self) {
46 self.connections.clear();
47 self.connected_users.clear();
48 self.channels.clear();
49 }
50
51 #[instrument(skip(self))]
52 pub fn add_connection(
53 &mut self,
54 connection_id: ConnectionId,
55 user_id: UserId,
56 admin: bool,
57 zed_version: ZedVersion,
58 ) {
59 self.connections.insert(
60 connection_id,
61 Connection {
62 user_id,
63 admin,
64 zed_version,
65 },
66 );
67 let connected_user = self.connected_users.entry(user_id).or_default();
68 connected_user.connection_ids.insert(connection_id);
69 }
70
71 #[instrument(skip(self))]
72 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
73 let connection = self
74 .connections
75 .get_mut(&connection_id)
76 .ok_or_else(|| anyhow!("no such connection"))?;
77
78 let user_id = connection.user_id;
79 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
80 connected_user.connection_ids.remove(&connection_id);
81 if connected_user.connection_ids.is_empty() {
82 self.connected_users.remove(&user_id);
83 self.channels.remove_user(&user_id);
84 }
85 self.connections.remove(&connection_id).unwrap();
86 Ok(())
87 }
88
89 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
90 self.connections.values()
91 }
92
93 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
94 self.connected_users
95 .get(&user_id)
96 .into_iter()
97 .flat_map(|state| {
98 state
99 .connection_ids
100 .iter()
101 .flat_map(|cid| self.connections.get(cid))
102 })
103 }
104
105 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
106 self.connected_users
107 .get(&user_id)
108 .into_iter()
109 .flat_map(|state| &state.connection_ids)
110 .copied()
111 }
112
113 pub fn channel_user_ids(
114 &self,
115 channel_id: ChannelId,
116 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
117 self.channels.users_to_notify(channel_id)
118 }
119
120 pub fn channel_connection_ids(
121 &self,
122 channel_id: ChannelId,
123 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
124 self.channels
125 .users_to_notify(channel_id)
126 .flat_map(|(user_id, role)| {
127 self.user_connection_ids(user_id)
128 .map(move |connection_id| (connection_id, role))
129 })
130 }
131
132 pub fn subscribe_to_channel(
133 &mut self,
134 user_id: UserId,
135 channel_id: ChannelId,
136 role: ChannelRole,
137 ) {
138 self.channels.subscribe(user_id, channel_id, role);
139 }
140
141 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
142 self.channels.unsubscribe(user_id, channel_id);
143 }
144
145 pub fn is_user_online(&self, user_id: UserId) -> bool {
146 !self
147 .connected_users
148 .get(&user_id)
149 .unwrap_or(&Default::default())
150 .connection_ids
151 .is_empty()
152 }
153
154 #[cfg(test)]
155 pub fn check_invariants(&self) {
156 for (connection_id, connection) in &self.connections {
157 assert!(self
158 .connected_users
159 .get(&connection.user_id)
160 .unwrap()
161 .connection_ids
162 .contains(connection_id));
163 }
164
165 for (user_id, state) in &self.connected_users {
166 for connection_id in &state.connection_ids {
167 assert_eq!(
168 self.connections.get(connection_id).unwrap().user_id,
169 *user_id
170 );
171 }
172 }
173 }
174}
175
/// Channel subscriptions, stored in two indices that `subscribe`/`unsubscribe` update together.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// For each user, the channels they are subscribed to and their role in each.
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// For each channel, the users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
181
182impl ChannelPool {
183 pub fn clear(&mut self) {
184 self.by_user.clear();
185 self.by_channel.clear();
186 }
187
188 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
189 self.by_user
190 .entry(user_id)
191 .or_default()
192 .insert(channel_id, role);
193 self.by_channel
194 .entry(channel_id)
195 .or_default()
196 .insert(user_id);
197 }
198
199 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
200 if let Some(channels) = self.by_user.get_mut(user_id) {
201 channels.remove(channel_id);
202 if channels.is_empty() {
203 self.by_user.remove(user_id);
204 }
205 }
206 if let Some(users) = self.by_channel.get_mut(channel_id) {
207 users.remove(user_id);
208 if users.is_empty() {
209 self.by_channel.remove(channel_id);
210 }
211 }
212 }
213
214 pub fn remove_user(&mut self, user_id: &UserId) {
215 if let Some(channels) = self.by_user.remove(&user_id) {
216 for channel_id in channels.keys() {
217 self.unsubscribe(user_id, &channel_id)
218 }
219 }
220 }
221
222 pub fn users_to_notify(
223 &self,
224 channel_id: ChannelId,
225 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
226 self.by_channel
227 .get(&channel_id)
228 .into_iter()
229 .flat_map(move |users| {
230 users.iter().flat_map(move |user_id| {
231 Some((
232 *user_id,
233 self.by_user
234 .get(user_id)
235 .and_then(|channels| channels.get(&channel_id))
236 .copied()?,
237 ))
238 })
239 })
240 }
241}