1use crate::db::{ChannelId, ChannelRole, UserId};
2use anyhow::{anyhow, Result};
3use collections::{BTreeMap, HashMap, HashSet};
4use rpc::ConnectionId;
5use serde::Serialize;
6use tracing::instrument;
7use util::SemanticVersion;
8
/// Registry of client connections, the users that own them, and those users'
/// channel subscriptions.
///
/// `connections` and `connected_users` are two views of the same data and are
/// kept in sync by `add_connection` / `remove_connection`.
#[derive(Default, Serialize)]
pub struct ConnectionPool {
    /// Every registered connection, keyed by its id.
    connections: BTreeMap<ConnectionId, Connection>,
    /// Reverse index from a user to the ids of their registered connections.
    connected_users: BTreeMap<UserId, ConnectedUser>,
    /// Channel subscriptions; a user's subscriptions are dropped when their
    /// last connection is removed.
    channels: ChannelPool,
}
15
/// Per-user bookkeeping inside a `ConnectionPool`.
#[derive(Default, Serialize)]
struct ConnectedUser {
    /// Ids of this user's registered connections. The pool removes the whole
    /// `ConnectedUser` entry once this set becomes empty.
    connection_ids: HashSet<ConnectionId>,
}
20
21#[derive(Debug, Serialize)]
22pub struct ZedVersion(pub SemanticVersion);
23use std::fmt;
24
25impl fmt::Display for ZedVersion {
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 write!(f, "{}", self.0)
28 }
29}
30
31impl ZedVersion {
32 pub fn is_supported(&self) -> bool {
33 self.0 != SemanticVersion::new(0, 123, 0)
34 }
35 pub fn supports_talker_role(&self) -> bool {
36 self.0 >= SemanticVersion::new(0, 125, 0)
37 }
38}
39
/// A single connection registered in a `ConnectionPool`.
#[derive(Serialize)]
pub struct Connection {
    /// The user this connection belongs to.
    pub user_id: UserId,
    /// Admin flag supplied to `ConnectionPool::add_connection`
    /// (presumably whether the user has admin rights — confirm with callers).
    pub admin: bool,
    /// Client version recorded for this connection (presumably the version
    /// the client reported when connecting — confirm with callers).
    pub zed_version: ZedVersion,
}
46
47impl ConnectionPool {
48 pub fn reset(&mut self) {
49 self.connections.clear();
50 self.connected_users.clear();
51 self.channels.clear();
52 }
53
54 #[instrument(skip(self))]
55 pub fn add_connection(
56 &mut self,
57 connection_id: ConnectionId,
58 user_id: UserId,
59 admin: bool,
60 zed_version: ZedVersion,
61 ) {
62 self.connections.insert(
63 connection_id,
64 Connection {
65 user_id,
66 admin,
67 zed_version,
68 },
69 );
70 let connected_user = self.connected_users.entry(user_id).or_default();
71 connected_user.connection_ids.insert(connection_id);
72 }
73
74 #[instrument(skip(self))]
75 pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> {
76 let connection = self
77 .connections
78 .get_mut(&connection_id)
79 .ok_or_else(|| anyhow!("no such connection"))?;
80
81 let user_id = connection.user_id;
82 let connected_user = self.connected_users.get_mut(&user_id).unwrap();
83 connected_user.connection_ids.remove(&connection_id);
84 if connected_user.connection_ids.is_empty() {
85 self.connected_users.remove(&user_id);
86 self.channels.remove_user(&user_id);
87 }
88 self.connections.remove(&connection_id).unwrap();
89 Ok(())
90 }
91
92 pub fn connections(&self) -> impl Iterator<Item = &Connection> {
93 self.connections.values()
94 }
95
96 pub fn user_connections(&self, user_id: UserId) -> impl Iterator<Item = &Connection> + '_ {
97 self.connected_users
98 .get(&user_id)
99 .into_iter()
100 .flat_map(|state| {
101 state
102 .connection_ids
103 .iter()
104 .flat_map(|cid| self.connections.get(cid))
105 })
106 }
107
108 pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ {
109 self.connected_users
110 .get(&user_id)
111 .into_iter()
112 .flat_map(|state| &state.connection_ids)
113 .copied()
114 }
115
116 pub fn channel_user_ids(
117 &self,
118 channel_id: ChannelId,
119 ) -> impl Iterator<Item = (UserId, ChannelRole)> + '_ {
120 self.channels.users_to_notify(channel_id)
121 }
122
123 pub fn channel_connection_ids(
124 &self,
125 channel_id: ChannelId,
126 ) -> impl Iterator<Item = (ConnectionId, ChannelRole)> + '_ {
127 self.channels
128 .users_to_notify(channel_id)
129 .flat_map(|(user_id, role)| {
130 self.user_connection_ids(user_id)
131 .map(move |connection_id| (connection_id, role))
132 })
133 }
134
135 pub fn subscribe_to_channel(
136 &mut self,
137 user_id: UserId,
138 channel_id: ChannelId,
139 role: ChannelRole,
140 ) {
141 self.channels.subscribe(user_id, channel_id, role);
142 }
143
144 pub fn unsubscribe_from_channel(&mut self, user_id: &UserId, channel_id: &ChannelId) {
145 self.channels.unsubscribe(user_id, channel_id);
146 }
147
148 pub fn is_user_online(&self, user_id: UserId) -> bool {
149 !self
150 .connected_users
151 .get(&user_id)
152 .unwrap_or(&Default::default())
153 .connection_ids
154 .is_empty()
155 }
156
157 #[cfg(test)]
158 pub fn check_invariants(&self) {
159 for (connection_id, connection) in &self.connections {
160 assert!(self
161 .connected_users
162 .get(&connection.user_id)
163 .unwrap()
164 .connection_ids
165 .contains(connection_id));
166 }
167
168 for (user_id, state) in &self.connected_users {
169 for connection_id in &state.connection_ids {
170 assert_eq!(
171 self.connections.get(connection_id).unwrap().user_id,
172 *user_id
173 );
174 }
175 }
176 }
177}
178
/// Two-way index of channel subscriptions.
///
/// `by_user` records each user's role in every channel they are subscribed
/// to; `by_channel` is the reverse index from a channel to its subscribed
/// users. `subscribe` and `unsubscribe` update both maps together and prune
/// entries that become empty.
#[derive(Default, Serialize)]
pub struct ChannelPool {
    /// User -> (channel -> role in that channel).
    by_user: HashMap<UserId, HashMap<ChannelId, ChannelRole>>,
    /// Channel -> users subscribed to it.
    by_channel: HashMap<ChannelId, HashSet<UserId>>,
}
184
185impl ChannelPool {
186 pub fn clear(&mut self) {
187 self.by_user.clear();
188 self.by_channel.clear();
189 }
190
191 pub fn subscribe(&mut self, user_id: UserId, channel_id: ChannelId, role: ChannelRole) {
192 self.by_user
193 .entry(user_id)
194 .or_default()
195 .insert(channel_id, role);
196 self.by_channel
197 .entry(channel_id)
198 .or_default()
199 .insert(user_id);
200 }
201
202 pub fn unsubscribe(&mut self, user_id: &UserId, channel_id: &ChannelId) {
203 if let Some(channels) = self.by_user.get_mut(user_id) {
204 channels.remove(channel_id);
205 if channels.is_empty() {
206 self.by_user.remove(user_id);
207 }
208 }
209 if let Some(users) = self.by_channel.get_mut(channel_id) {
210 users.remove(user_id);
211 if users.is_empty() {
212 self.by_channel.remove(channel_id);
213 }
214 }
215 }
216
217 pub fn remove_user(&mut self, user_id: &UserId) {
218 if let Some(channels) = self.by_user.remove(&user_id) {
219 for channel_id in channels.keys() {
220 self.unsubscribe(user_id, &channel_id)
221 }
222 }
223 }
224
225 pub fn users_to_notify(
226 &self,
227 channel_id: ChannelId,
228 ) -> impl '_ + Iterator<Item = (UserId, ChannelRole)> {
229 self.by_channel
230 .get(&channel_id)
231 .into_iter()
232 .flat_map(move |users| {
233 users.iter().flat_map(move |user_id| {
234 Some((
235 *user_id,
236 self.by_user
237 .get(user_id)
238 .and_then(|channels| channels.get(&channel_id))
239 .copied()?,
240 ))
241 })
242 })
243 }
244}